diff --git a/.buildkite/run-benchmarks.sh b/.buildkite/run-benchmarks.sh index 6283cd1064017..75e9cf6a6579a 100644 --- a/.buildkite/run-benchmarks.sh +++ b/.buildkite/run-benchmarks.sh @@ -54,12 +54,12 @@ tail -n 24 benchmark_serving.txt >> benchmark_results.md # last 24 lines echo '```' >> benchmark_results.md # if the agent binary is not found, skip uploading the results, exit 0 -if [ ! -f /workspace/buildkite-agent ]; then +if [ ! -f buildkite-agent ]; then exit 0 fi # upload the results to buildkite -/workspace/buildkite-agent annotate --style "info" --context "benchmark-results" < benchmark_results.md +buildkite-agent annotate --style "info" --context "benchmark-results" < benchmark_results.md # exit with the exit code of the benchmarks if [ $bench_latency_exit_code -ne 0 ]; then @@ -75,4 +75,4 @@ if [ $bench_serving_exit_code -ne 0 ]; then fi rm ShareGPT_V3_unfiltered_cleaned_split.json -/workspace/buildkite-agent artifact upload "*.json" +buildkite-agent artifact upload "*.json" diff --git a/.buildkite/test-template-aws.j2 b/.buildkite/test-template-aws.j2 index 9f7d07acca298..3b5d36b246673 100644 --- a/.buildkite/test-template-aws.j2 +++ b/.buildkite/test-template-aws.j2 @@ -22,7 +22,9 @@ steps: {% for step in steps %} - label: "{{ step.label }}" agents: - {% if step.no_gpu %} + {% if step.label == "Documentation Build" %} + queue: small_cpu_queue + {% elif step.no_gpu %} queue: cpu_queue {% elif step.num_gpus == 2 or step.num_gpus == 4 %} queue: gpu_4_queue @@ -47,6 +49,9 @@ steps: {% if not step.no_gpu %} gpus: all {% endif %} + {% if step.label == "Benchmarks" %} + mount-buildkite-agent: true + {% endif %} command: ["bash", "-c", "cd {{ (step.working_dir or default_working_dir) | safe }} && {{ step.command or (step.commands | join(' && ')) | safe }}"] environment: - VLLM_USAGE_SOURCE=ci-test diff --git a/CMakeLists.txt b/CMakeLists.txt index a197063f33601..ad6736c47f459 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -66,19 +66,6 @@ endif() # find_package(Torch REQUIRED) -# -# Normally `torch.utils.cpp_extension.CUDAExtension` would add -# `libtorch_python.so` for linking against an extension. Torch's cmake -# configuration does not include this library (presumably since the cmake -# config is used for standalone C++ binaries that link against torch). -# The `libtorch_python.so` library defines some of the glue code between -# torch/python via pybind and is required by VLLM extensions for this -# reason. So, add it by manually with `find_library` using torch's -# installed library path. -# -find_library(torch_python_LIBRARY torch_python PATHS - "${TORCH_INSTALL_PREFIX}/lib") - # # Forward the non-CUDA device extensions to external CMake scripts. 
# @@ -171,7 +158,7 @@ set(VLLM_EXT_SRC "csrc/quantization/fp8/common.cu" "csrc/cuda_utils_kernels.cu" "csrc/moe_align_block_size_kernels.cu" - "csrc/pybind.cpp") + "csrc/torch_bindings.cpp") if(VLLM_GPU_LANG STREQUAL "CUDA") include(FetchContent) @@ -218,6 +205,7 @@ define_gpu_extension_target( COMPILE_FLAGS ${VLLM_GPU_FLAGS} ARCHITECTURES ${VLLM_GPU_ARCHES} INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIR};${CUTLASS_TOOLS_UTIL_INCLUDE_DIR} + USE_SABI 3 WITH_SOABI) # @@ -225,7 +213,7 @@ define_gpu_extension_target( # set(VLLM_MOE_EXT_SRC - "csrc/moe/moe_ops.cpp" + "csrc/moe/torch_bindings.cpp" "csrc/moe/topk_softmax_kernels.cu") define_gpu_extension_target( @@ -235,6 +223,7 @@ define_gpu_extension_target( SOURCES ${VLLM_MOE_EXT_SRC} COMPILE_FLAGS ${VLLM_GPU_FLAGS} ARCHITECTURES ${VLLM_GPU_ARCHES} + USE_SABI 3 WITH_SOABI) # @@ -249,7 +238,7 @@ set(VLLM_PUNICA_EXT_SRC "csrc/punica/bgmv/bgmv_fp32_bf16_bf16.cu" "csrc/punica/bgmv/bgmv_fp32_fp16_fp16.cu" "csrc/punica/punica_ops.cu" - "csrc/punica/punica_pybind.cpp") + "csrc/punica/torch_bindings.cpp") # # Copy GPU compilation flags+update for punica @@ -286,6 +275,7 @@ if (VLLM_PUNICA_GPU_ARCHES) SOURCES ${VLLM_PUNICA_EXT_SRC} COMPILE_FLAGS ${VLLM_PUNICA_GPU_FLAGS} ARCHITECTURES ${VLLM_PUNICA_GPU_ARCHES} + USE_SABI 3 WITH_SOABI) else() message(WARNING "Unable to create _punica_C target because none of the " diff --git a/Dockerfile.rocm b/Dockerfile.rocm index e30a2aaf30209..954958df88fc0 100644 --- a/Dockerfile.rocm +++ b/Dockerfile.rocm @@ -106,9 +106,9 @@ RUN --mount=type=cache,target=/root/.cache/pip \ pip install -U -r requirements-rocm.txt \ && patch /opt/rocm/include/hip/amd_detail/amd_hip_bf16.h ./rocm_patch/rocm_bf16.patch \ && python3 setup.py install \ - && cp build/lib.linux-x86_64-cpython-39/vllm/_C.cpython-39-x86_64-linux-gnu.so vllm/ \ - && cp build/lib.linux-x86_64-cpython-39/vllm/_punica_C.cpython-39-x86_64-linux-gnu.so vllm/ \ - && cp build/lib.linux-x86_64-cpython-39/vllm/_moe_C.cpython-39-x86_64-linux-gnu.so vllm/ \ + && cp build/lib.linux-x86_64-cpython-39/vllm/_C.abi3.so vllm/ \ + && cp build/lib.linux-x86_64-cpython-39/vllm/_punica_C.abi3.so vllm/ \ + && cp build/lib.linux-x86_64-cpython-39/vllm/_moe_C.abi3.so vllm/ \ && cd .. diff --git a/benchmarks/benchmark_latency.py b/benchmarks/benchmark_latency.py index f69d91a086a9f..1a41b66b38824 100644 --- a/benchmarks/benchmark_latency.py +++ b/benchmarks/benchmark_latency.py @@ -36,7 +36,8 @@ def main(args: argparse.Namespace): enable_chunked_prefill=args.enable_chunked_prefill, download_dir=args.download_dir, block_size=args.block_size, - gpu_memory_utilization=args.gpu_memory_utilization) + gpu_memory_utilization=args.gpu_memory_utilization, + distributed_executor_backend=args.distributed_executor_backend) sampling_params = SamplingParams( n=args.n, @@ -221,5 +222,12 @@ def run_to_completion(profile_dir: Optional[str] = None): help='the fraction of GPU memory to be used for ' 'the model executor, which can range from 0 to 1.' 'If unspecified, will use the default value of 0.9.') + parser.add_argument( + '--distributed-executor-backend', + choices=['ray', 'mp'], + default=None, + help='Backend to use for distributed serving. 
When more than 1 GPU ' + 'is used, will be automatically set to "ray" if installed ' + 'or "mp" (multiprocessing) otherwise.') args = parser.parse_args() main(args) diff --git a/benchmarks/benchmark_throughput.py b/benchmarks/benchmark_throughput.py index 7c8cb5ee8cea2..90f7433e0ae28 100644 --- a/benchmarks/benchmark_throughput.py +++ b/benchmarks/benchmark_throughput.py @@ -78,6 +78,7 @@ def run_vllm( enable_prefix_caching: bool, enable_chunked_prefill: bool, max_num_batched_tokens: int, + distributed_executor_backend: Optional[str], gpu_memory_utilization: float = 0.9, download_dir: Optional[str] = None, ) -> float: @@ -100,6 +101,7 @@ def run_vllm( download_dir=download_dir, enable_chunked_prefill=enable_chunked_prefill, max_num_batched_tokens=max_num_batched_tokens, + distributed_executor_backend=distributed_executor_backend, ) # Add the requests to the engine. @@ -225,8 +227,8 @@ def main(args: argparse.Namespace): args.enforce_eager, args.kv_cache_dtype, args.quantization_param_path, args.device, args.enable_prefix_caching, args.enable_chunked_prefill, - args.max_num_batched_tokens, args.gpu_memory_utilization, - args.download_dir) + args.max_num_batched_tokens, args.distributed_executor_backend, + args.gpu_memory_utilization, args.download_dir) elif args.backend == "hf": assert args.tensor_parallel_size == 1 elapsed_time = run_hf(requests, args.model, tokenizer, args.n, @@ -368,6 +370,13 @@ def main(args: argparse.Namespace): type=str, default=None, help='Path to save the throughput results in JSON format.') + parser.add_argument( + '--distributed-executor-backend', + choices=['ray', 'mp'], + default=None, + help='Backend to use for distributed serving. When more than 1 GPU ' + 'is used, will be automatically set to "ray" if installed ' + 'or "mp" (multiprocessing) otherwise.') args = parser.parse_args() if args.tokenizer is None: args.tokenizer = args.model diff --git a/cmake/cpu_extension.cmake b/cmake/cpu_extension.cmake index 0cf37769a6960..61d4843838ba0 100644 --- a/cmake/cpu_extension.cmake +++ b/cmake/cpu_extension.cmake @@ -12,7 +12,7 @@ include_directories("${CMAKE_SOURCE_DIR}/csrc") # # Check the compile flags # -list(APPEND CXX_COMPILE_FLAGS +list(APPEND CXX_COMPILE_FLAGS "-fopenmp" "-DVLLM_CPU_EXTENSION") @@ -44,8 +44,8 @@ if (AVX512_FOUND) find_isa(${CPUINFO} "avx512_bf16" AVX512BF16_FOUND) if (AVX512BF16_FOUND OR ENABLE_AVX512BF16) - if (CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND - CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 12.3) + if (CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND + CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 12.3) list(APPEND CXX_COMPILE_FLAGS "-mavx512bf16") else() message(WARNING "Disable AVX512-BF16 ISA support, requires gcc/g++ >= 12.3") @@ -73,7 +73,7 @@ set(VLLM_EXT_SRC "csrc/cpu/cache.cpp" "csrc/cpu/layernorm.cpp" "csrc/cpu/pos_encoding.cpp" - "csrc/cpu/pybind.cpp") + "csrc/cpu/torch_bindings.cpp") define_gpu_extension_target( _C @@ -81,10 +81,10 @@ define_gpu_extension_target( LANGUAGE CXX SOURCES ${VLLM_EXT_SRC} COMPILE_FLAGS ${CXX_COMPILE_FLAGS} - WITH_SOABI + USE_SABI 3 + WITH_SOABI ) add_custom_target(default) message(STATUS "Enabling C extension.") add_dependencies(default _C) - diff --git a/cmake/utils.cmake b/cmake/utils.cmake index 00c81e4d00ad8..f3c1286dd8498 100644 --- a/cmake/utils.cmake +++ b/cmake/utils.cmake @@ -5,7 +5,7 @@ macro (find_python_from_executable EXECUTABLE SUPPORTED_VERSIONS) file(REAL_PATH ${EXECUTABLE} EXECUTABLE) set(Python_EXECUTABLE ${EXECUTABLE}) - find_package(Python COMPONENTS Interpreter 
Development.Module) + find_package(Python COMPONENTS Interpreter Development.Module Development.SABIModule) if (NOT Python_FOUND) message(FATAL_ERROR "Unable to find python matching: ${EXECUTABLE}.") endif() @@ -294,6 +294,7 @@ endmacro() # INCLUDE_DIRECTORIES - Extra include directories. # LIBRARIES - Extra link libraries. # WITH_SOABI - Generate library with python SOABI suffix name. +# USE_SABI - Use python stable api # # Note: optimization level/debug info is set via cmake build type. # @@ -301,7 +302,7 @@ function (define_gpu_extension_target GPU_MOD_NAME) cmake_parse_arguments(PARSE_ARGV 1 GPU "WITH_SOABI" - "DESTINATION;LANGUAGE" + "DESTINATION;LANGUAGE;USE_SABI" "SOURCES;ARCHITECTURES;COMPILE_FLAGS;INCLUDE_DIRECTORIES;LIBRARIES") # Add hipify preprocessing step when building with HIP/ROCm. @@ -315,7 +316,11 @@ function (define_gpu_extension_target GPU_MOD_NAME) set(GPU_WITH_SOABI) endif() - Python_add_library(${GPU_MOD_NAME} MODULE "${GPU_SOURCES}" ${GPU_WITH_SOABI}) + if (GPU_USE_SABI) + Python_add_library(${GPU_MOD_NAME} MODULE USE_SABI ${GPU_USE_SABI} ${GPU_WITH_SOABI} "${GPU_SOURCES}") + else() + Python_add_library(${GPU_MOD_NAME} MODULE ${GPU_WITH_SOABI} "${GPU_SOURCES}") + endif() if (GPU_LANGUAGE STREQUAL "HIP") # Make this target dependent on the hipify preprocessor step. diff --git a/csrc/activation_kernels.cu b/csrc/activation_kernels.cu index 867f63f12de4b..86ac2e75e78ee 100644 --- a/csrc/activation_kernels.cu +++ b/csrc/activation_kernels.cu @@ -1,5 +1,5 @@ #include -#include +#include #include #include diff --git a/csrc/attention/attention_kernels.cu b/csrc/attention/attention_kernels.cu index 8f89f89786c3b..91083481705cb 100644 --- a/csrc/attention/attention_kernels.cu +++ b/csrc/attention/attention_kernels.cu @@ -17,7 +17,7 @@ * limitations under the License. 
*/ -#include +#include #include #include #include @@ -808,16 +808,17 @@ void paged_attention_v1( torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] torch::Tensor& - value_cache, // [num_blocks, num_heads, head_size, block_size] - int num_kv_heads, // [num_heads] - float scale, + value_cache, // [num_blocks, num_heads, head_size, block_size] + int64_t num_kv_heads, // [num_heads] + double scale, torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq] torch::Tensor& seq_lens, // [num_seqs] - int block_size, int max_seq_len, + int64_t block_size, int64_t max_seq_len, const c10::optional& alibi_slopes, - const std::string& kv_cache_dtype, float kv_scale, const int tp_rank, - const int blocksparse_local_blocks, const int blocksparse_vert_stride, - const int blocksparse_block_size, const int blocksparse_head_sliding_step) { + const std::string& kv_cache_dtype, double kv_scale, const int64_t tp_rank, + const int64_t blocksparse_local_blocks, + const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, + const int64_t blocksparse_head_sliding_step) { const bool is_block_sparse = (blocksparse_vert_stride > 1); DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype, @@ -972,16 +973,17 @@ void paged_attention_v2( torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] torch::Tensor& - value_cache, // [num_blocks, num_heads, head_size, block_size] - int num_kv_heads, // [num_heads] - float scale, + value_cache, // [num_blocks, num_heads, head_size, block_size] + int64_t num_kv_heads, // [num_heads] + double scale, torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq] torch::Tensor& seq_lens, // [num_seqs] - int block_size, int max_seq_len, + int64_t block_size, int64_t max_seq_len, const c10::optional& alibi_slopes, - const std::string& kv_cache_dtype, float kv_scale, const int tp_rank, - const int blocksparse_local_blocks, const int blocksparse_vert_stride, - const int blocksparse_block_size, const int blocksparse_head_sliding_step) { + const std::string& kv_cache_dtype, double kv_scale, const int64_t tp_rank, + const int64_t blocksparse_local_blocks, + const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, + const int64_t blocksparse_head_sliding_step) { const bool is_block_sparse = (blocksparse_vert_stride > 1); DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype, CALL_V2_LAUNCHER_BLOCK_SIZE) @@ -990,4 +992,4 @@ void paged_attention_v2( #undef WARP_SIZE #undef MAX #undef MIN -#undef DIVIDE_ROUND_UP \ No newline at end of file +#undef DIVIDE_ROUND_UP diff --git a/csrc/cache.h b/csrc/cache.h index 435ae3e57f555..86caa9345361d 100644 --- a/csrc/cache.h +++ b/csrc/cache.h @@ -1,6 +1,6 @@ #pragma once -#include +#include #include #include @@ -8,14 +8,18 @@ void swap_blocks(torch::Tensor& src, torch::Tensor& dst, const torch::Tensor& block_mapping); -void copy_blocks(std::vector& key_caches, - std::vector& value_caches, +// Note: the key_caches and value_caches vectors are constant but +// not the Tensors they contain. The vectors need to be const refs +// in order to satisfy pytorch's C++ operator registration code. 
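
[Editorial aside on the const-ref note above; the amended copy_blocks declaration follows. A schema-level `Tensor[]` argument, as in the `copy_blocks(Tensor[]! ...)` schemas in the new torch_bindings.cpp files further down, is accepted by the unboxed registration path as a `std::vector<torch::Tensor>` passed by const reference. The vector itself is immutable, but `torch::Tensor` is a reference-counted handle, so the tensors it refers to can still be written through. A minimal sketch of that pairing, using a hypothetical `touch_all` op and namespace that are not part of this PR:]

    // Hypothetical sketch: pairing a "Tensor[]!" schema argument with a
    // const-ref vector parameter. `touch_all` and `_example_C` are
    // illustrative only.
    #include <torch/all.h>
    #include <torch/library.h>
    #include <vector>

    // The vector is const, but torch::Tensor is a handle type, so the
    // tensors it contains can still be mutated in place.
    void touch_all(std::vector<torch::Tensor> const& xs) {
      for (auto& x : xs) {
        x.zero_();  // in-place write through a const handle
      }
    }

    TORCH_LIBRARY(_example_C, m) {
      // "Tensor[]!" marks a list whose elements are mutated in place.
      m.def("touch_all(Tensor[]! xs) -> ()");
      m.impl("touch_all", torch::kCPU, &touch_all);
    }
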
+void copy_blocks(std::vector const& key_caches, + std::vector const& value_caches, const torch::Tensor& block_mapping); void reshape_and_cache(torch::Tensor& key, torch::Tensor& value, torch::Tensor& key_cache, torch::Tensor& value_cache, torch::Tensor& slot_mapping, - const std::string& kv_cache_dtype, const float kv_scale); + const std::string& kv_cache_dtype, + const double kv_scale); void reshape_and_cache_flash(torch::Tensor& key, torch::Tensor& value, torch::Tensor& key_cache, @@ -25,4 +29,4 @@ void reshape_and_cache_flash(torch::Tensor& key, torch::Tensor& value, // Just for unittest void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache, - const float scale, const std::string& kv_cache_dtype); + const double scale, const std::string& kv_cache_dtype); diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu index d924ac39b89ca..72041076ae009 100644 --- a/csrc/cache_kernels.cu +++ b/csrc/cache_kernels.cu @@ -1,4 +1,4 @@ -#include +#include #include #include @@ -95,8 +95,11 @@ __global__ void copy_blocks_kernel(int64_t* key_cache_ptrs, } // namespace vllm -void copy_blocks(std::vector& key_caches, - std::vector& value_caches, +// Note: the key_caches and value_caches vectors are constant but +// not the Tensors they contain. The vectors need to be const refs +// in order to satisfy pytorch's C++ operator registration code. +void copy_blocks(std::vector const& key_caches, + std::vector const& value_caches, const torch::Tensor& block_mapping) { int num_layers = key_caches.size(); TORCH_CHECK(num_layers == value_caches.size()); @@ -255,7 +258,7 @@ void reshape_and_cache( torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size] torch::Tensor& slot_mapping, // [num_tokens] - const std::string& kv_cache_dtype, const float kv_scale) { + const std::string& kv_cache_dtype, const double kv_scale) { int num_tokens = key.size(0); int num_heads = key.size(1); int head_size = key.size(2); @@ -334,7 +337,7 @@ __global__ void convert_fp8_kernel(const Tin* __restrict__ src_cache, // Only for testing. 
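
[Editorial aside on the scalar widenings running through these signatures; the convert_fp8 definition with its float-to-double kv_scale change continues directly below. TORCH_LIBRARY schemas use Python-level numeric types only: a schema `float` binds to C++ `double` and a schema `int` binds to C++ `int64_t`, so narrower C++ parameter types would fail to match the registered schema. A minimal sketch of the mapping, with a hypothetical `scale_blocks` op not taken from this PR:]

    // Hypothetical op showing the schema-to-C++ scalar mapping:
    // schema "float" -> C++ double, schema "int" -> C++ int64_t.
    #include <torch/all.h>
    #include <torch/library.h>

    void scale_blocks(torch::Tensor& out, double scale, int64_t block_size) {
      out.mul_(scale);
      (void)block_size;  // unused in this sketch
    }

    TORCH_LIBRARY(_example2_C, m) {
      m.def("scale_blocks(Tensor! out, float scale, int block_size) -> ()");
      m.impl("scale_blocks", torch::kCPU, &scale_blocks);
    }
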
void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache, - const float kv_scale, const std::string& kv_cache_dtype) { + const double kv_scale, const std::string& kv_cache_dtype) { torch::Device src_device = src_cache.device(); torch::Device dst_device = dst_cache.device(); TORCH_CHECK(src_device.is_cuda(), "src must be on a GPU") diff --git a/csrc/cpu/attention.cpp b/csrc/cpu/attention.cpp index ed8cfbd421f0f..8367093325314 100644 --- a/csrc/cpu/attention.cpp +++ b/csrc/cpu/attention.cpp @@ -420,12 +420,13 @@ void paged_attention_v1_impl_launcher( void paged_attention_v1( torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache, - torch::Tensor& value_cache, int num_kv_heads, float scale, - torch::Tensor& block_tables, torch::Tensor& seq_lens, int block_size, - int max_seq_len, const c10::optional& alibi_slopes, - const std::string& kv_cache_dtype, float kv_scale, const int tp_rank, - const int blocksparse_local_blocks, const int blocksparse_vert_stride, - const int blocksparse_block_size, const int blocksparse_head_sliding_step) { + torch::Tensor& value_cache, int64_t num_kv_heads, double scale, + torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size, + int64_t max_seq_len, const c10::optional& alibi_slopes, + const std::string& kv_cache_dtype, double kv_scale, const int64_t tp_rank, + const int64_t blocksparse_local_blocks, + const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, + const int64_t blocksparse_head_sliding_step) { TORCH_CHECK(kv_scale == 1.0f); TORCH_CHECK(blocksparse_vert_stride <= 1, "CPU backend does not support blocksparse attention yet."); @@ -738,12 +739,13 @@ void paged_attention_v2_impl_launcher( void paged_attention_v2( torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits, torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache, - torch::Tensor& value_cache, int num_kv_heads, float scale, - torch::Tensor& block_tables, torch::Tensor& seq_lens, int block_size, - int max_seq_len, const c10::optional& alibi_slopes, - const std::string& kv_cache_dtype, float kv_scale, const int tp_rank, - const int blocksparse_local_blocks, const int blocksparse_vert_stride, - const int blocksparse_block_size, const int blocksparse_head_sliding_step) { + torch::Tensor& value_cache, int64_t num_kv_heads, double scale, + torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size, + int64_t max_seq_len, const c10::optional& alibi_slopes, + const std::string& kv_cache_dtype, double kv_scale, const int64_t tp_rank, + const int64_t blocksparse_local_blocks, + const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, + const int64_t blocksparse_head_sliding_step) { TORCH_CHECK(kv_scale == 1.0f); TORCH_CHECK(blocksparse_vert_stride <= 1, "CPU backend does not support blocksparse attention yet."); diff --git a/csrc/cpu/cache.cpp b/csrc/cpu/cache.cpp index 2890ba6e2bb32..2b5c3bd6ee70b 100644 --- a/csrc/cpu/cache.cpp +++ b/csrc/cpu/cache.cpp @@ -5,8 +5,8 @@ namespace { template -void copy_blocks_cpu_impl(std::vector& key_caches, - std::vector& value_caches, +void copy_blocks_cpu_impl(std::vector const& key_caches, + std::vector const& value_caches, const torch::Tensor& mapping_pairs, const int element_num_per_block, const int layer_num) { @@ -82,8 +82,11 @@ void reshape_and_cache_cpu_impl( } }; // namespace -void copy_blocks(std::vector& key_caches, - std::vector& value_caches, +// Note: the key_caches and value_caches vectors are constant but +// not the Tensors they 
contain. The vectors need to be const refs +// in order to satisfy pytorch's C++ operator registration code. +void copy_blocks(std::vector const& key_caches, + std::vector const& value_caches, const torch::Tensor& block_mapping) { unsigned num_layers = key_caches.size(); TORCH_CHECK(num_layers == value_caches.size()); @@ -104,7 +107,7 @@ void copy_blocks(std::vector& key_caches, void reshape_and_cache(torch::Tensor& key, torch::Tensor& value, torch::Tensor& key_cache, torch::Tensor& value_cache, torch::Tensor& slot_mapping, - const std::string& kv_cache_dtype, float kv_scale) { + const std::string& kv_cache_dtype, double kv_scale) { TORCH_CHECK(kv_scale == 1.0f); int num_tokens = key.size(0); diff --git a/csrc/cpu/cpu_types.hpp b/csrc/cpu/cpu_types.hpp index c1d3ec058b991..034c406a532d5 100644 --- a/csrc/cpu/cpu_types.hpp +++ b/csrc/cpu/cpu_types.hpp @@ -3,7 +3,7 @@ #define CPU_TYPES_HPP #include -#include +#include namespace vec_op { diff --git a/csrc/cpu/layernorm.cpp b/csrc/cpu/layernorm.cpp index 65d3ddcec5709..a76ad08928a2c 100644 --- a/csrc/cpu/layernorm.cpp +++ b/csrc/cpu/layernorm.cpp @@ -88,7 +88,7 @@ void fused_add_rms_norm_impl(scalar_t* __restrict__ input, } // namespace void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight, - float epsilon) { + double epsilon) { int hidden_size = input.size(-1); int num_tokens = input.numel() / hidden_size; @@ -102,7 +102,7 @@ void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight, } void fused_add_rms_norm(torch::Tensor& input, torch::Tensor& residual, - torch::Tensor& weight, float epsilon) { + torch::Tensor& weight, double epsilon) { int hidden_size = input.size(-1); int num_tokens = input.numel() / hidden_size; diff --git a/csrc/cpu/pos_encoding.cpp b/csrc/cpu/pos_encoding.cpp index e8aead17ae5a7..96bce7dda0132 100644 --- a/csrc/cpu/pos_encoding.cpp +++ b/csrc/cpu/pos_encoding.cpp @@ -168,7 +168,7 @@ void rotary_embedding_gptj_impl( }; // namespace void rotary_embedding(torch::Tensor& positions, torch::Tensor& query, - torch::Tensor& key, int head_size, + torch::Tensor& key, int64_t head_size, torch::Tensor& cos_sin_cache, bool is_neox) { int num_tokens = query.numel() / query.size(-1); int rot_dim = cos_sin_cache.size(1); diff --git a/csrc/cpu/pybind.cpp b/csrc/cpu/pybind.cpp deleted file mode 100644 index 63082393c8102..0000000000000 --- a/csrc/cpu/pybind.cpp +++ /dev/null @@ -1,44 +0,0 @@ -#include "cache.h" -#include "cuda_utils.h" -#include "ops.h" -#include - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - // vLLM custom ops - pybind11::module ops = m.def_submodule("ops", "vLLM custom operators"); - - // Attention ops - ops.def("paged_attention_v1", &paged_attention_v1, - "Compute the attention between an input query and the cached " - "keys/values using PagedAttention."); - ops.def("paged_attention_v2", &paged_attention_v2, "PagedAttention V2."); - - // Activation ops - ops.def("silu_and_mul", &silu_and_mul, "Activation function used in SwiGLU."); - ops.def("gelu_and_mul", &gelu_and_mul, - "Activation function used in GeGLU with `none` approximation."); - ops.def("gelu_tanh_and_mul", &gelu_tanh_and_mul, - "Activation function used in GeGLU with `tanh` approximation."); - ops.def("gelu_new", &gelu_new, "GELU implementation used in GPT-2."); - ops.def("gelu_fast", &gelu_fast, "Approximate GELU implementation."); - - // Layernorm - ops.def("rms_norm", &rms_norm, - "Apply Root Mean Square (RMS) Normalization to the input tensor."); - - ops.def("fused_add_rms_norm", &fused_add_rms_norm, - 
"In-place fused Add and RMS Normalization"); - - // Rotary embedding - ops.def("rotary_embedding", &rotary_embedding, - "Apply GPT-NeoX or GPT-J style rotary embedding to query and key"); - - // Cache ops - pybind11::module cache_ops = m.def_submodule("cache_ops", "vLLM cache ops"); - cache_ops.def("swap_blocks", &swap_blocks, - "Swap in (out) the cache blocks from src to dst"); - cache_ops.def("copy_blocks", ©_blocks, - "Copy the cache blocks from src to dst"); - cache_ops.def("reshape_and_cache", &reshape_and_cache, - "Reshape the key and value tensors and cache them"); -} diff --git a/csrc/cpu/torch_bindings.cpp b/csrc/cpu/torch_bindings.cpp new file mode 100644 index 0000000000000..a2bf0d49adba5 --- /dev/null +++ b/csrc/cpu/torch_bindings.cpp @@ -0,0 +1,106 @@ +#include "cache.h" +#include "ops.h" +#include "registration.h" + +#include + +TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { + // vLLM custom ops + + // Attention ops + // Compute the attention between an input query and the cached keys/values + // using PagedAttention. + ops.def( + "paged_attention_v1(" + " Tensor! out, Tensor query, Tensor key_cache," + " Tensor value_cache, int num_kv_heads, float scale," + " Tensor block_tables, Tensor seq_lens, int block_size," + " int max_seq_len, Tensor? alibi_slopes," + " str kv_cache_dtype, float kv_scale, int tp_rank," + " int blocksparse_local_blocks," + " int blocksparse_vert_stride, int blocksparse_block_size," + " int blocksparse_head_sliding_step) -> ()"); + ops.impl("paged_attention_v1", torch::kCPU, &paged_attention_v1); + + // PagedAttention V2. + ops.def( + "paged_attention_v2(" + " Tensor! out, Tensor exp_sums, Tensor max_logits," + " Tensor tmp_out, Tensor query, Tensor key_cache," + " Tensor value_cache, int num_kv_heads, float scale," + " Tensor block_tables, Tensor seq_lens, int block_size," + " int max_seq_len, Tensor? alibi_slopes," + " str kv_cache_dtype, float kv_scale, int tp_rank," + " int blocksparse_local_blocks," + " int blocksparse_vert_stride, int blocksparse_block_size," + " int blocksparse_head_sliding_step) -> ()"); + ops.impl("paged_attention_v2", torch::kCPU, &paged_attention_v2); + + // Activation ops + + // Activation function used in SwiGLU. + ops.def("silu_and_mul(Tensor! out, Tensor input) -> ()"); + ops.impl("silu_and_mul", torch::kCPU, &silu_and_mul); + + // Activation function used in GeGLU with `none` approximation. + ops.def("gelu_and_mul(Tensor! out, Tensor input) -> ()"); + ops.impl("gelu_and_mul", torch::kCPU, &gelu_and_mul); + + // Activation function used in GeGLU with `tanh` approximation. + ops.def("gelu_tanh_and_mul(Tensor! out, Tensor input) -> ()"); + ops.impl("gelu_tanh_and_mul", torch::kCPU, &gelu_tanh_and_mul); + + // GELU implementation used in GPT-2. + ops.def("gelu_new(Tensor! out, Tensor input) -> ()"); + ops.impl("gelu_new", torch::kCPU, &gelu_new); + + // Approximate GELU implementation. + ops.def("gelu_fast(Tensor! out, Tensor input) -> ()"); + ops.impl("gelu_fast", torch::kCPU, &gelu_fast); + + // Layernorm + // Apply Root Mean Square (RMS) Normalization to the input tensor. + ops.def( + "rms_norm(Tensor! out, Tensor input, Tensor weight, float epsilon) -> " + "()"); + ops.impl("rms_norm", torch::kCPU, &rms_norm); + + // In-place fused Add and RMS Normalization. + ops.def( + "fused_add_rms_norm(Tensor! input, Tensor! 
residual, Tensor weight, " + "float epsilon) -> ()"); + ops.impl("fused_add_rms_norm", torch::kCPU, &fused_add_rms_norm); + + // Rotary embedding + // Apply GPT-NeoX or GPT-J style rotary embedding to query and key. + ops.def( + "rotary_embedding(Tensor positions, Tensor! query," + " Tensor! key, int head_size," + " Tensor cos_sin_cache, bool is_neox) -> ()"); + ops.impl("rotary_embedding", torch::kCPU, &rotary_embedding); +} + +TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) { + // Cache ops + // Swap in (out) the cache blocks from src to dst. + cache_ops.def( + "swap_blocks(Tensor src, Tensor! dst, Tensor block_mapping) -> ()"); + cache_ops.impl("swap_blocks", torch::kCPU, &swap_blocks); + + // Copy the cache blocks from src to dst. + cache_ops.def( + "copy_blocks(Tensor[]! key_caches, Tensor[]! value_caches, Tensor " + "block_mapping) -> ()"); + cache_ops.impl("copy_blocks", torch::kCPU, ©_blocks); + + // Reshape the key and value tensors and cache them. + cache_ops.def( + "reshape_and_cache(Tensor key, Tensor value," + " Tensor! key_cache, Tensor! value_cache," + " Tensor slot_mapping," + " str kv_cache_dtype," + " float kv_scale) -> ()"); + cache_ops.impl("reshape_and_cache", torch::kCPU, &reshape_and_cache); +} + +REGISTER_EXTENSION(TORCH_EXTENSION_NAME) diff --git a/csrc/cuda_utils.h b/csrc/cuda_utils.h index 2ba49b339e148..73944f4c14890 100644 --- a/csrc/cuda_utils.h +++ b/csrc/cuda_utils.h @@ -1,7 +1,5 @@ #pragma once -#include +int64_t get_device_attribute(int64_t attribute, int64_t device_id); -int get_device_attribute(int attribute, int device_id); - -int get_max_shared_memory_per_block_device_attribute(int device_id); +int64_t get_max_shared_memory_per_block_device_attribute(int64_t device_id); diff --git a/csrc/cuda_utils_kernels.cu b/csrc/cuda_utils_kernels.cu index 7d8e2e19720fa..d6f9eb646fad5 100644 --- a/csrc/cuda_utils_kernels.cu +++ b/csrc/cuda_utils_kernels.cu @@ -2,7 +2,7 @@ #include #include #endif -int get_device_attribute(int attribute, int device_id) { +int64_t get_device_attribute(int64_t attribute, int64_t device_id) { int device, value; if (device_id < 0) { cudaGetDevice(&device); @@ -14,8 +14,8 @@ int get_device_attribute(int attribute, int device_id) { return value; } -int get_max_shared_memory_per_block_device_attribute(int device_id) { - int attribute; +int64_t get_max_shared_memory_per_block_device_attribute(int64_t device_id) { + int64_t attribute; // https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html // cudaDevAttrMaxSharedMemoryPerBlockOptin = 97 if not is_hip() else 74 diff --git a/csrc/custom_all_reduce.cu b/csrc/custom_all_reduce.cu index 0b1d95848525a..82a3563979f16 100644 --- a/csrc/custom_all_reduce.cu +++ b/csrc/custom_all_reduce.cu @@ -1,17 +1,17 @@ #include #include #include -#include +#include #include "custom_all_reduce.cuh" -// fake pointer type -using fptr_t = uint64_t; +// fake pointer type, must match fptr_t type in ops.h +using fptr_t = int64_t; static_assert(sizeof(void*) == sizeof(fptr_t)); fptr_t init_custom_ar(torch::Tensor& meta, torch::Tensor& rank_data, const std::vector& handles, - const std::vector& offsets, int rank, + const std::vector& offsets, int64_t rank, bool full_nvlink) { int world_size = offsets.size(); if (world_size > 8) @@ -55,7 +55,7 @@ bool _is_weak_contiguous(torch::Tensor& t) { t.numel() * t.element_size()); } -bool should_custom_ar(torch::Tensor& inp, int max_size, int world_size, +bool should_custom_ar(torch::Tensor& inp, int64_t max_size, int64_t world_size, bool 
full_nvlink) { auto inp_size = inp.numel() * inp.element_size(); // custom allreduce requires input byte size to be multiples of 16 @@ -125,7 +125,7 @@ void dispose(fptr_t _fa) { delete fa; } -int meta_size() { return sizeof(vllm::Signal); } +int64_t meta_size() { return sizeof(vllm::Signal); } void register_buffer(fptr_t _fa, torch::Tensor& t, const std::vector& handles, @@ -134,10 +134,16 @@ void register_buffer(fptr_t _fa, torch::Tensor& t, fa->register_buffer(handles, offsets, t.data_ptr()); } -std::pair, std::vector> get_graph_buffer_ipc_meta( +std::tuple> get_graph_buffer_ipc_meta( fptr_t _fa) { auto fa = reinterpret_cast(_fa); - return fa->get_graph_buffer_ipc_meta(); + auto [handle_bytes, offsets] = fa->get_graph_buffer_ipc_meta(); + auto options = + torch::TensorOptions().dtype(torch::kUInt8).device(torch::kCPU); + auto handles = + torch::empty({static_cast(handle_bytes.size())}, options); + std::memcpy(handles.data_ptr(), handle_bytes.data(), handle_bytes.size()); + return {handles, std::move(offsets)}; } void register_graph_buffers(fptr_t _fa, const std::vector& handles, diff --git a/csrc/dispatch_utils.h b/csrc/dispatch_utils.h index 3ecea03242f06..a634e1c3d4886 100644 --- a/csrc/dispatch_utils.h +++ b/csrc/dispatch_utils.h @@ -4,7 +4,7 @@ */ #pragma once -#include +#include #define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \ AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \ diff --git a/csrc/layernorm_kernels.cu b/csrc/layernorm_kernels.cu index 70a2b3b0a07b1..ca1c04bd880d9 100644 --- a/csrc/layernorm_kernels.cu +++ b/csrc/layernorm_kernels.cu @@ -1,4 +1,4 @@ -#include +#include #include #include @@ -291,7 +291,7 @@ fused_add_rms_norm_kernel( void rms_norm(torch::Tensor& out, // [..., hidden_size] torch::Tensor& input, // [..., hidden_size] torch::Tensor& weight, // [hidden_size] - float epsilon) { + double epsilon) { int hidden_size = input.size(-1); int num_tokens = input.numel() / hidden_size; @@ -319,7 +319,7 @@ void rms_norm(torch::Tensor& out, // [..., hidden_size] void fused_add_rms_norm(torch::Tensor& input, // [..., hidden_size] torch::Tensor& residual, // [..., hidden_size] torch::Tensor& weight, // [hidden_size] - float epsilon) { + double epsilon) { int hidden_size = input.size(-1); int num_tokens = input.numel() / hidden_size; diff --git a/csrc/moe/moe_ops.cpp b/csrc/moe/moe_ops.cpp deleted file mode 100644 index 4122f7630d7c7..0000000000000 --- a/csrc/moe/moe_ops.cpp +++ /dev/null @@ -1,8 +0,0 @@ -#include "moe_ops.h" - -#include - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("topk_softmax", &topk_softmax, - "Apply topk softmax to the gating outputs."); -} diff --git a/csrc/moe/moe_ops.h b/csrc/moe/moe_ops.h index 93e7844ac1993..a251730aa765a 100644 --- a/csrc/moe/moe_ops.h +++ b/csrc/moe/moe_ops.h @@ -1,6 +1,6 @@ #pragma once -#include +#include void topk_softmax(torch::Tensor& topk_weights, torch::Tensor& topk_indices, torch::Tensor& token_expert_indices, diff --git a/csrc/moe/topk_softmax_kernels.cu b/csrc/moe/topk_softmax_kernels.cu index 6ba4fcdb3a3f2..de9747b602524 100644 --- a/csrc/moe/topk_softmax_kernels.cu +++ b/csrc/moe/topk_softmax_kernels.cu @@ -16,7 +16,7 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -#include +#include #include #include #include "../cuda_compat.h" diff --git a/csrc/moe/torch_bindings.cpp b/csrc/moe/torch_bindings.cpp new file mode 100644 index 0000000000000..243752b9a9e8c --- /dev/null +++ b/csrc/moe/torch_bindings.cpp @@ -0,0 +1,12 @@ +#include "registration.h" +#include "moe_ops.h" + +TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { + // Apply topk softmax to the gating outputs. + m.def( + "topk_softmax(Tensor! topk_weights, Tensor! topk_indices, Tensor! " + "token_expert_indices, Tensor gating_output) -> ()"); + m.impl("topk_softmax", torch::kCUDA, &topk_softmax); +} + +REGISTER_EXTENSION(TORCH_EXTENSION_NAME) diff --git a/csrc/moe_align_block_size_kernels.cu b/csrc/moe_align_block_size_kernels.cu index edc441d121029..1f8d75da83bb8 100644 --- a/csrc/moe_align_block_size_kernels.cu +++ b/csrc/moe_align_block_size_kernels.cu @@ -1,4 +1,4 @@ -#include +#include #include #include @@ -108,8 +108,8 @@ __global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids, } } // namespace vllm -void moe_align_block_size(torch::Tensor topk_ids, int num_experts, - int block_size, torch::Tensor sorted_token_ids, +void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, + int64_t block_size, torch::Tensor sorted_token_ids, torch::Tensor experts_ids, torch::Tensor num_tokens_post_pad) { const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); diff --git a/csrc/ops.h b/csrc/ops.h index 06b60e748886f..0c270a78c331f 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -1,40 +1,42 @@ #pragma once -#include +#include void paged_attention_v1( torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache, - torch::Tensor& value_cache, int num_kv_heads, float scale, - torch::Tensor& block_tables, torch::Tensor& seq_lens, int block_size, - int max_seq_len, const c10::optional& alibi_slopes, - const std::string& kv_cache_dtype, float kv_scale, const int tp_rank, - const int blocksparse_local_blocks, const int blocksparse_vert_stride, - const int blocksparse_block_size, const int blocksparse_head_sliding_step); + torch::Tensor& value_cache, int64_t num_kv_heads, double scale, + torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size, + int64_t max_seq_len, const c10::optional& alibi_slopes, + const std::string& kv_cache_dtype, double kv_scale, const int64_t tp_rank, + const int64_t blocksparse_local_blocks, + const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, + const int64_t blocksparse_head_sliding_step); void paged_attention_v2( torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits, torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache, - torch::Tensor& value_cache, int num_kv_heads, float scale, - torch::Tensor& block_tables, torch::Tensor& seq_lens, int block_size, - int max_seq_len, const c10::optional& alibi_slopes, - const std::string& kv_cache_dtype, float kv_scale, const int tp_rank, - const int blocksparse_local_blocks, const int blocksparse_vert_stride, - const int blocksparse_block_size, const int blocksparse_head_sliding_step); + torch::Tensor& value_cache, int64_t num_kv_heads, double scale, + torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size, + int64_t max_seq_len, const c10::optional& alibi_slopes, + const std::string& kv_cache_dtype, double kv_scale, const int64_t tp_rank, + const int64_t blocksparse_local_blocks, + const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, + const int64_t blocksparse_head_sliding_step); void 
rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight, - float epsilon); + double epsilon); void fused_add_rms_norm(torch::Tensor& input, torch::Tensor& residual, - torch::Tensor& weight, float epsilon); + torch::Tensor& weight, double epsilon); void rotary_embedding(torch::Tensor& positions, torch::Tensor& query, - torch::Tensor& key, int head_size, + torch::Tensor& key, int64_t head_size, torch::Tensor& cos_sin_cache, bool is_neox); void batched_rotary_embedding(torch::Tensor& positions, torch::Tensor& query, - torch::Tensor& key, int head_size, + torch::Tensor& key, int64_t head_size, torch::Tensor& cos_sin_cache, bool is_neox, - int rot_dim, + int64_t rot_dim, torch::Tensor& cos_sin_cache_offsets); void silu_and_mul(torch::Tensor& out, torch::Tensor& input); @@ -60,12 +62,12 @@ torch::Tensor aqlm_dequant(const torch::Tensor& codes, torch::Tensor awq_gemm(torch::Tensor _in_feats, torch::Tensor _kernel, torch::Tensor _scaling_factors, torch::Tensor _zeros, - int split_k_iters); + int64_t split_k_iters); torch::Tensor awq_dequantize(torch::Tensor _kernel, torch::Tensor _scaling_factors, - torch::Tensor _zeros, int split_k_iters, int thx, - int thy); + torch::Tensor _zeros, int64_t split_k_iters, + int64_t thx, int64_t thy); torch::Tensor marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, torch::Tensor& b_scales, torch::Tensor& workspace, @@ -88,9 +90,9 @@ torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm, int64_t size_k, int64_t size_n, int64_t num_bits); -int cutlass_scaled_mm_dq(torch::Tensor& out, torch::Tensor const& a, - torch::Tensor const& b, torch::Tensor const& a_scales, - torch::Tensor const& b_scales); +void cutlass_scaled_mm_dq(torch::Tensor& out, torch::Tensor const& a, + torch::Tensor const& b, torch::Tensor const& a_scales, + torch::Tensor const& b_scales); #endif @@ -106,9 +108,9 @@ void squeezellm_gemm(torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight, torch::Tensor b_gptq_qzeros, torch::Tensor b_gptq_scales, torch::Tensor b_g_idx, - bool use_exllama, int bit); + bool use_exllama, int64_t bit); -void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int bit); +void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int64_t bit); void static_scaled_fp8_quant(torch::Tensor& out, torch::Tensor& input, torch::Tensor& scale); @@ -116,28 +118,28 @@ void static_scaled_fp8_quant(torch::Tensor& out, torch::Tensor& input, void dynamic_scaled_fp8_quant(torch::Tensor& out, torch::Tensor& input, torch::Tensor& scale); -void moe_align_block_size(torch::Tensor topk_ids, int num_experts, - int block_size, torch::Tensor sorted_token_ids, +void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, + int64_t block_size, torch::Tensor sorted_token_ids, torch::Tensor experts_ids, torch::Tensor num_tokens_post_pad); #ifndef USE_ROCM -using fptr_t = uint64_t; +using fptr_t = int64_t; fptr_t init_custom_ar(torch::Tensor& meta, torch::Tensor& rank_data, const std::vector& handles, - const std::vector& offsets, int rank, + const std::vector& offsets, int64_t rank, bool full_nvlink); -bool should_custom_ar(torch::Tensor& inp, int max_size, int world_size, +bool should_custom_ar(torch::Tensor& inp, int64_t max_size, int64_t world_size, bool full_nvlink); void all_reduce_reg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out); void all_reduce_unreg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& reg_buffer, torch::Tensor& out); void dispose(fptr_t 
_fa); -int meta_size(); +int64_t meta_size(); void register_buffer(fptr_t _fa, torch::Tensor& t, const std::vector& handles, const std::vector& offsets); -std::pair, std::vector> get_graph_buffer_ipc_meta( +std::tuple> get_graph_buffer_ipc_meta( fptr_t _fa); void register_graph_buffers(fptr_t _fa, const std::vector& handles, const std::vector>& offsets); diff --git a/csrc/pos_encoding_kernels.cu b/csrc/pos_encoding_kernels.cu index 69d6dae1c26bc..97184a8735593 100644 --- a/csrc/pos_encoding_kernels.cu +++ b/csrc/pos_encoding_kernels.cu @@ -1,4 +1,4 @@ -#include +#include #include #include @@ -127,7 +127,7 @@ void rotary_embedding( // [num_tokens, num_heads * head_size] torch::Tensor& key, // [batch_size, seq_len, num_kv_heads * head_size] or // [num_tokens, num_kv_heads * head_size] - int head_size, + int64_t head_size, torch::Tensor& cos_sin_cache, // [max_position, rot_dim] bool is_neox) { int64_t num_tokens = query.numel() / query.size(-1); @@ -138,7 +138,7 @@ void rotary_embedding( int64_t key_stride = key.stride(-2); dim3 grid(num_tokens); - dim3 block(std::min(num_heads * rot_dim / 2, 512)); + dim3 block(std::min(num_heads * rot_dim / 2, 512)); const at::cuda::OptionalCUDAGuard device_guard(device_of(query)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "rotary_embedding", [&] { @@ -168,9 +168,9 @@ void batched_rotary_embedding( // [num_tokens, num_heads * head_size] torch::Tensor& key, // [batch_size, seq_len, num_kv_heads * head_size] or // [num_tokens, num_kv_heads * head_size] - int head_size, + int64_t head_size, torch::Tensor& cos_sin_cache, // [max_position, rot_dim] - bool is_neox, int rot_dim, + bool is_neox, int64_t rot_dim, torch::Tensor& cos_sin_cache_offsets // [num_tokens] ) { int64_t num_tokens = cos_sin_cache_offsets.size(0); @@ -180,7 +180,7 @@ void batched_rotary_embedding( int64_t key_stride = key.stride(-2); dim3 grid(num_tokens); - dim3 block(std::min(num_heads * rot_dim / 2, 512)); + dim3 block(std::min(num_heads * rot_dim / 2, 512)); const at::cuda::OptionalCUDAGuard device_guard(device_of(query)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "rotary_embedding", [&] { diff --git a/csrc/punica/punica_ops.cu b/csrc/punica/punica_ops.cu index 61de3b37937cc..dd29820144b34 100644 --- a/csrc/punica/punica_ops.cu +++ b/csrc/punica/punica_ops.cu @@ -1,4 +1,4 @@ -#include +#include #include #include @@ -88,7 +88,7 @@ inline bool launch_bgmv_kernel(out_T *Y, const in_T *X, const W_T *W, } void dispatch_bgmv(torch::Tensor y, torch::Tensor x, torch::Tensor w, - torch::Tensor indicies, int64_t layer_idx, float scale) { + torch::Tensor indicies, int64_t layer_idx, double scale) { CHECK_INPUT(y); CHECK_INPUT(x); CHECK_INPUT(w); @@ -320,7 +320,7 @@ void dispatch_bgmv(torch::Tensor y, torch::Tensor x, torch::Tensor w, void dispatch_bgmv_low_level(torch::Tensor y, torch::Tensor x, torch::Tensor w, torch::Tensor indicies, int64_t layer_idx, - float scale, int64_t h_in, int64_t h_out, + double scale, int64_t h_in, int64_t h_out, int64_t y_offset) { CHECK_INPUT(y); CHECK_INPUT(x); diff --git a/csrc/punica/punica_ops.h b/csrc/punica/punica_ops.h index 937e2d1d25d4a..5d625d0564f75 100644 --- a/csrc/punica/punica_ops.h +++ b/csrc/punica/punica_ops.h @@ -1,11 +1,11 @@ #pragma once -#include +#include void dispatch_bgmv(torch::Tensor y, torch::Tensor x, torch::Tensor w, - torch::Tensor indicies, int64_t layer_idx, float scale); + torch::Tensor indicies, 
int64_t layer_idx, double scale); void dispatch_bgmv_low_level(torch::Tensor y, torch::Tensor x, torch::Tensor w, torch::Tensor indicies, int64_t layer_idx, - float scale, int64_t h_in, int64_t h_out, + double scale, int64_t h_in, int64_t h_out, int64_t y_offset); diff --git a/csrc/punica/punica_pybind.cpp b/csrc/punica/punica_pybind.cpp deleted file mode 100644 index 9490ad59cdd5f..0000000000000 --- a/csrc/punica/punica_pybind.cpp +++ /dev/null @@ -1,13 +0,0 @@ -#include - -#include "punica_ops.h" - -//====== pybind ====== - -#define DEFINE_pybind(name) m.def(#name, &name, #name); - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("dispatch_bgmv", &dispatch_bgmv, "dispatch_bgmv"); - m.def("dispatch_bgmv_low_level", &dispatch_bgmv_low_level, - "dispatch_bgmv_low_level"); -} diff --git a/csrc/punica/torch_bindings.cpp b/csrc/punica/torch_bindings.cpp new file mode 100644 index 0000000000000..894e229b6d9db --- /dev/null +++ b/csrc/punica/torch_bindings.cpp @@ -0,0 +1,18 @@ +#include "registration.h" +#include "punica_ops.h" + +TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { + m.def( + "dispatch_bgmv(Tensor! y, Tensor x, Tensor w, Tensor indicies, int " + "layer_idx, float scale) -> ()"); + m.impl("dispatch_bgmv", torch::kCUDA, &dispatch_bgmv); + + m.def( + "dispatch_bgmv_low_level(Tensor! y, Tensor x, Tensor w," + "Tensor indicies, int layer_idx," + "float scale, int h_in, int h_out," + "int y_offset) -> ()"); + m.impl("dispatch_bgmv_low_level", torch::kCUDA, &dispatch_bgmv_low_level); +} + +REGISTER_EXTENSION(TORCH_EXTENSION_NAME) diff --git a/csrc/pybind.cpp b/csrc/pybind.cpp deleted file mode 100644 index 547823aa1b04e..0000000000000 --- a/csrc/pybind.cpp +++ /dev/null @@ -1,114 +0,0 @@ -#include "cache.h" -#include "cuda_utils.h" -#include "ops.h" -#include - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - // vLLM custom ops - pybind11::module ops = m.def_submodule("ops", "vLLM custom operators"); - - // Attention ops - ops.def("paged_attention_v1", &paged_attention_v1, - "Compute the attention between an input query and the cached " - "keys/values using PagedAttention."); - ops.def("paged_attention_v2", &paged_attention_v2, "PagedAttention V2."); - - // Activation ops - ops.def("silu_and_mul", &silu_and_mul, "Activation function used in SwiGLU."); - ops.def("gelu_and_mul", &gelu_and_mul, - "Activation function used in GeGLU with `none` approximation."); - ops.def("gelu_tanh_and_mul", &gelu_tanh_and_mul, - "Activation function used in GeGLU with `tanh` approximation."); - ops.def("gelu_new", &gelu_new, "GELU implementation used in GPT-2."); - ops.def("gelu_fast", &gelu_fast, "Approximate GELU implementation."); - - // Layernorm - ops.def("rms_norm", &rms_norm, - "Apply Root Mean Square (RMS) Normalization to the input tensor."); - - ops.def("fused_add_rms_norm", &fused_add_rms_norm, - "In-place fused Add and RMS Normalization"); - - // Rotary embedding - ops.def("rotary_embedding", &rotary_embedding, - "Apply GPT-NeoX or GPT-J style rotary embedding to query and key"); - - ops.def("batched_rotary_embedding", &batched_rotary_embedding, - "Apply GPT-NeoX or GPT-J style rotary embedding to query and key " - "(supports multiple loras)"); - -// Quantization ops -#ifndef USE_ROCM - ops.def("aqlm_gemm", &aqlm_gemm, "Quantized GEMM for AQLM"); - ops.def("aqlm_dequant", &aqlm_dequant, "Decompression method for AQLM"); - ops.def("awq_gemm", &awq_gemm, "Quantized GEMM for AWQ"); - ops.def("marlin_gemm", &marlin_gemm, - "Marlin (Dense) Optimized Quantized GEMM for GPTQ"); - 
ops.def("gptq_marlin_24_gemm", &gptq_marlin_24_gemm, - "Marlin_24 (Sparse) Optimized Quantized GEMM for GPTQ"); - ops.def("gptq_marlin_gemm", &gptq_marlin_gemm, - "gptq_marlin Optimized Quantized GEMM for GPTQ"); - ops.def("gptq_marlin_repack", &gptq_marlin_repack, - "gptq_marlin repack from GPTQ"); - ops.def("awq_dequantize", &awq_dequantize, "Dequantization for AWQ"); - ops.def("cutlass_scaled_mm_dq", &cutlass_scaled_mm_dq, - "CUTLASS w8a8 GEMM, supporting symmetric per-tensor or " - "per-row/column quantization."); -#endif - - ops.def("gptq_gemm", &gptq_gemm, "Quantized GEMM for GPTQ"); - ops.def("gptq_shuffle", &gptq_shuffle, "Post processing for GPTQ"); - ops.def("squeezellm_gemm", &squeezellm_gemm, "Quantized GEMM for SqueezeLLM"); - ops.def("static_scaled_fp8_quant", &static_scaled_fp8_quant, - "Compute FP8 quantized tensor for given scaling factor"); - ops.def("dynamic_scaled_fp8_quant", &dynamic_scaled_fp8_quant, - "Compute FP8 quantized tensor and scaling factor"); - ops.def("moe_align_block_size", &moe_align_block_size, - "Aligning the number of tokens to be processed by each expert such " - "that it is divisible by the block size."); - - ops.def("static_scaled_int8_quant", &static_scaled_int8_quant, - "Compute int8 quantized tensor for given scaling factor"); - - ops.def("dynamic_scaled_int8_quant", &dynamic_scaled_int8_quant, - "Compute int8 quantized tensor and scaling factor"); - - // Cache ops - pybind11::module cache_ops = m.def_submodule("cache_ops", "vLLM cache ops"); - cache_ops.def("swap_blocks", &swap_blocks, - "Swap in (out) the cache blocks from src to dst"); - cache_ops.def("copy_blocks", ©_blocks, - "Copy the cache blocks from src to dst"); - cache_ops.def("reshape_and_cache", &reshape_and_cache, - "Reshape the key and value tensors and cache them"); - cache_ops.def("reshape_and_cache_flash", &reshape_and_cache_flash, - "Reshape the key and value tensors and cache them"); - cache_ops.def("convert_fp8", &convert_fp8, - "Convert the key and value cache to fp8 data type"); - - // Cuda utils - pybind11::module cuda_utils = - m.def_submodule("cuda_utils", "vLLM cuda utils"); - cuda_utils.def("get_device_attribute", &get_device_attribute, - "Gets the specified device attribute."); - - cuda_utils.def("get_max_shared_memory_per_block_device_attribute", - &get_max_shared_memory_per_block_device_attribute, - "Gets the maximum shared memory per block device attribute."); - -#ifndef USE_ROCM - // Custom all-reduce kernels - pybind11::module custom_ar = m.def_submodule("custom_ar", "custom allreduce"); - custom_ar.def("init_custom_ar", &init_custom_ar, "init_custom_ar"); - custom_ar.def("should_custom_ar", &should_custom_ar, "should_custom_ar"); - custom_ar.def("all_reduce_reg", &all_reduce_reg, "all_reduce_reg"); - custom_ar.def("all_reduce_unreg", &all_reduce_unreg, "all_reduce_unreg"); - custom_ar.def("dispose", &dispose, "dispose"); - custom_ar.def("meta_size", &meta_size, "meta_size"); - custom_ar.def("register_buffer", ®ister_buffer, "register_buffer"); - custom_ar.def("get_graph_buffer_ipc_meta", &get_graph_buffer_ipc_meta, - "get_graph_buffer_ipc_meta"); - custom_ar.def("register_graph_buffers", ®ister_graph_buffers, - "register_graph_buffers"); -#endif -} diff --git a/csrc/quantization/aqlm/gemm_kernels.cu b/csrc/quantization/aqlm/gemm_kernels.cu index 255844eec56d4..8fb9856800867 100644 --- a/csrc/quantization/aqlm/gemm_kernels.cu +++ b/csrc/quantization/aqlm/gemm_kernels.cu @@ -18,7 +18,7 @@ #include #include #include -#include +#include #include #include diff --git 
a/csrc/quantization/awq/gemm_kernels.cu b/csrc/quantization/awq/gemm_kernels.cu index bb8e5bbb23d7f..6d6da5f3d8746 100644 --- a/csrc/quantization/awq/gemm_kernels.cu +++ b/csrc/quantization/awq/gemm_kernels.cu @@ -7,7 +7,7 @@ Shang and Dang, Xingyu and Han, Song}, journal={arXiv}, year={2023} } */ -#include +#include #include #include "dequantize.cuh" @@ -435,8 +435,8 @@ __global__ void __launch_bounds__(64) torch::Tensor awq_dequantize(torch::Tensor _kernel, torch::Tensor _scaling_factors, - torch::Tensor _zeros, int split_k_iters, int thx, - int thy) { + torch::Tensor _zeros, int64_t split_k_iters, + int64_t thx, int64_t thy) { int in_c = _kernel.size(0); int qout_c = _kernel.size(1); int out_c = qout_c * 8; @@ -491,7 +491,7 @@ torch::Tensor awq_dequantize(torch::Tensor _kernel, torch::Tensor awq_gemm(torch::Tensor _in_feats, torch::Tensor _kernel, torch::Tensor _scaling_factors, torch::Tensor _zeros, - int split_k_iters) { + int64_t split_k_iters) { int num_in_feats = _in_feats.size(0); int num_in_channels = _in_feats.size(1); const at::cuda::OptionalCUDAGuard device_guard(device_of(_in_feats)); diff --git a/csrc/quantization/compressed_tensors/int8_quant_kernels.cu b/csrc/quantization/compressed_tensors/int8_quant_kernels.cu index 280b0327111da..aa9511daa2772 100644 --- a/csrc/quantization/compressed_tensors/int8_quant_kernels.cu +++ b/csrc/quantization/compressed_tensors/int8_quant_kernels.cu @@ -1,5 +1,5 @@ #include -#include +#include #include #include "../../dispatch_utils.h" diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_dq_c2x.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_dq_c2x.cu index 088fee4783faa..23a8b4070b70e 100644 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_dq_c2x.cu +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_dq_c2x.cu @@ -1,5 +1,5 @@ #include -#include +#include #include diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_dq_c3x.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_dq_c3x.cu index 8fc4ba662ecdd..a99802153643a 100644 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_dq_c3x.cu +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_dq_c3x.cu @@ -4,7 +4,7 @@ #if defined CUDA_VERSION && CUDA_VERSION >= 12000 -#include +#include #include diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_dq_entry.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_dq_entry.cu index eb532f2ac7a9b..423e64a4932e2 100644 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_dq_entry.cu +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_dq_entry.cu @@ -1,7 +1,7 @@ #include #include -#include +#include void cutlass_scaled_mm_dq_sm75(torch::Tensor& c, torch::Tensor const& a, torch::Tensor const& b, diff --git a/csrc/quantization/fp8/common.cu b/csrc/quantization/fp8/common.cu index 55be3305a9b8c..8c5b693bf6ed7 100644 --- a/csrc/quantization/fp8/common.cu +++ b/csrc/quantization/fp8/common.cu @@ -1,5 +1,5 @@ #include -#include +#include #include #include diff --git a/csrc/quantization/gptq/q_gemm.cu b/csrc/quantization/gptq/q_gemm.cu index 480c4986c3821..785f1a09c1900 100644 --- a/csrc/quantization/gptq/q_gemm.cu +++ b/csrc/quantization/gptq/q_gemm.cu @@ -6,7 +6,7 @@ https://github.com/qwopqwop200/GPTQ-for-LLaMa #include #include -#include +#include #include #include #include @@ -1823,7 +1823,7 @@ void shuffle_exllama_weight(uint32_t* q_weight, int* q_perm, int height, torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight, torch::Tensor b_gptq_qzeros, torch::Tensor b_gptq_scales, torch::Tensor b_g_idx, - bool use_exllama, int bit) { + bool use_exllama, int64_t bit) { const 
at::cuda::OptionalCUDAGuard device_guard(device_of(a)); auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device()); at::Tensor c = torch::empty({a.size(0), b_q_weight.size(1)}, options); @@ -1845,7 +1845,7 @@ torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight, return c; } -void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int bit) { +void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int64_t bit) { const at::cuda::OptionalCUDAGuard device_guard(device_of(q_weight)); vllm::gptq::shuffle_exllama_weight( (uint32_t*)q_weight.data_ptr(), diff --git a/csrc/quantization/gptq_marlin/gptq_marlin.cu b/csrc/quantization/gptq_marlin/gptq_marlin.cu index c573b9041065b..0beb9de14c687 100644 --- a/csrc/quantization/gptq_marlin/gptq_marlin.cu +++ b/csrc/quantization/gptq_marlin/gptq_marlin.cu @@ -1867,4 +1867,4 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, return c; } -#endif \ No newline at end of file +#endif diff --git a/csrc/quantization/gptq_marlin/gptq_marlin.cuh b/csrc/quantization/gptq_marlin/gptq_marlin.cuh index ba5368ea8835f..42af44951efda 100644 --- a/csrc/quantization/gptq_marlin/gptq_marlin.cuh +++ b/csrc/quantization/gptq_marlin/gptq_marlin.cuh @@ -1,6 +1,6 @@ #pragma once -#include +#include #include #include diff --git a/csrc/quantization/marlin/dense/marlin_cuda_kernel.cu b/csrc/quantization/marlin/dense/marlin_cuda_kernel.cu index 03d66cecedf1f..d124c0149912d 100644 --- a/csrc/quantization/marlin/dense/marlin_cuda_kernel.cu +++ b/csrc/quantization/marlin/dense/marlin_cuda_kernel.cu @@ -15,7 +15,7 @@ * limitations under the License. */ -#include +#include #include #include diff --git a/csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu b/csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu index 686dd7851e6af..b5effc3055441 100644 --- a/csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu +++ b/csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu @@ -16,7 +16,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include +#include #include #include diff --git a/csrc/quantization/squeezellm/quant_cuda_kernel.cu b/csrc/quantization/squeezellm/quant_cuda_kernel.cu index 1b339fa4b392b..40baac6108695 100644 --- a/csrc/quantization/squeezellm/quant_cuda_kernel.cu +++ b/csrc/quantization/squeezellm/quant_cuda_kernel.cu @@ -1,5 +1,4 @@ #include -#include #include #include #include diff --git a/csrc/registration.h b/csrc/registration.h new file mode 100644 index 0000000000000..e5396e9a8b137 --- /dev/null +++ b/csrc/registration.h @@ -0,0 +1,22 @@ +#pragma once + +#include + +#define _CONCAT(A, B) A##B +#define CONCAT(A, B) _CONCAT(A, B) + +#define _STRINGIFY(A) #A +#define STRINGIFY(A) _STRINGIFY(A) + +// A version of the TORCH_LIBRARY macro that expands the NAME, i.e. so NAME +// could be a macro instead of a literal token. +#define TORCH_LIBRARY_EXPAND(NAME, MODULE) TORCH_LIBRARY(NAME, MODULE) + +// REGISTER_EXTENSION allows the shared library to be loaded and initialized +// via python's import statement. 
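
[Editorial aside on how the three macros fit together; the REGISTER_EXTENSION definition itself follows. The build defines TORCH_EXTENSION_NAME to the target's module name (e.g. `_moe_C`), so a bindings file such as the new csrc/moe/torch_bindings.cpp expands into a `TORCH_LIBRARY(_moe_C, m)` registration block plus a `PyInit__moe_C` entry point, which is what lets Python import the shared object. A minimal sketch of a complete bindings translation unit, with a hypothetical `noop` op standing in for the real kernels:]

    // Hedged sketch of a bindings file built with these macros; `noop`
    // is hypothetical, but the structure mirrors the real
    // csrc/moe/torch_bindings.cpp added in this diff.
    #include "registration.h"
    #include <torch/all.h>
    #include <torch/library.h>

    void noop(torch::Tensor& out) { (void)out; }

    // Expands to TORCH_LIBRARY(<module name>, m); the build system
    // defines TORCH_EXTENSION_NAME per target (e.g. _moe_C).
    TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
      m.def("noop(Tensor! out) -> ()");
      m.impl("noop", torch::kCPU, &noop);
    }

    // Emits PyInit_<module name>, making the .so importable from Python.
    REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
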
+#define REGISTER_EXTENSION(NAME) \ + PyMODINIT_FUNC CONCAT(PyInit_, NAME)() { \ + static struct PyModuleDef module = {PyModuleDef_HEAD_INIT, \ + STRINGIFY(NAME), nullptr, 0, nullptr}; \ + return PyModule_Create(&module); \ + } diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp new file mode 100644 index 0000000000000..df2603544c85a --- /dev/null +++ b/csrc/torch_bindings.cpp @@ -0,0 +1,283 @@ +#include "cache.h" +#include "cuda_utils.h" +#include "ops.h" +#include "registration.h" + +#include <torch/library.h> + +// Note on op signatures: +// The X_meta signatures are for the meta functions corresponding to op X. +// They must be kept in sync with the signature for X. Generally, only +// functions that return Tensors require a meta function. +// +// See the following links for detailed docs on op registration and function +// schemas. +// https://docs.google.com/document/d/1_W62p8WJOQQUzPsJYa7s701JXt0qf2OfLub2sbkHOaU/edit#heading=h.ptttacy8y1u9 +// https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/README.md#annotations + +TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { + // vLLM custom ops + + // Attention ops + // Compute the attention between an input query and the cached + // keys/values using PagedAttention. + ops.def( + "paged_attention_v1(" + " Tensor! out, Tensor query, Tensor key_cache," + " Tensor value_cache, int num_kv_heads, float scale," + " Tensor block_tables, Tensor seq_lens, int block_size," + " int max_seq_len, Tensor? alibi_slopes," + " str kv_cache_dtype, float kv_scale, int tp_rank," + " int blocksparse_local_blocks," + " int blocksparse_vert_stride, int blocksparse_block_size," + " int blocksparse_head_sliding_step) -> ()"); + ops.impl("paged_attention_v1", torch::kCUDA, &paged_attention_v1); + + // PagedAttention V2. + ops.def( + "paged_attention_v2(" + " Tensor! out, Tensor exp_sums, Tensor max_logits," + " Tensor tmp_out, Tensor query, Tensor key_cache," + " Tensor value_cache, int num_kv_heads, float scale," + " Tensor block_tables, Tensor seq_lens, int block_size," + " int max_seq_len, Tensor? alibi_slopes," + " str kv_cache_dtype, float kv_scale, int tp_rank," + " int blocksparse_local_blocks," + " int blocksparse_vert_stride, int blocksparse_block_size," + " int blocksparse_head_sliding_step) -> ()"); + ops.impl("paged_attention_v2", torch::kCUDA, &paged_attention_v2); + + // Activation ops + // Activation function used in SwiGLU. + ops.def("silu_and_mul(Tensor! out, Tensor input) -> ()"); + ops.impl("silu_and_mul", torch::kCUDA, &silu_and_mul); + + // Activation function used in GeGLU with `none` approximation. + ops.def("gelu_and_mul(Tensor! out, Tensor input) -> ()"); + ops.impl("gelu_and_mul", torch::kCUDA, &gelu_and_mul); + + // Activation function used in GeGLU with `tanh` approximation. + ops.def("gelu_tanh_and_mul(Tensor! out, Tensor input) -> ()"); + ops.impl("gelu_tanh_and_mul", torch::kCUDA, &gelu_tanh_and_mul); + + // GELU implementation used in GPT-2. + ops.def("gelu_new(Tensor! out, Tensor input) -> ()"); + ops.impl("gelu_new", torch::kCUDA, &gelu_new); + + // Approximate GELU implementation. + ops.def("gelu_fast(Tensor! out, Tensor input) -> ()"); + ops.impl("gelu_fast", torch::kCUDA, &gelu_fast); + + // Layernorm + // Apply Root Mean Square (RMS) Normalization to the input tensor. + ops.def( + "rms_norm(Tensor! out, Tensor input, Tensor weight, float epsilon) -> " + "()"); + ops.impl("rms_norm", torch::kCUDA, &rms_norm); + + // In-place fused Add and RMS Normalization. + ops.def( + "fused_add_rms_norm(Tensor! input, Tensor!
residual, Tensor weight, " + "float epsilon) -> ()"); + ops.impl("fused_add_rms_norm", torch::kCUDA, &fused_add_rms_norm); + + // Rotary embedding + // Apply GPT-NeoX or GPT-J style rotary embedding to query and key. + ops.def( + "rotary_embedding(Tensor positions, Tensor! query," + " Tensor! key, int head_size," + " Tensor cos_sin_cache, bool is_neox) -> ()"); + ops.impl("rotary_embedding", torch::kCUDA, &rotary_embedding); + + // Apply GPT-NeoX or GPT-J style rotary embedding to query and key + // (supports multiple loras). + ops.def( + "batched_rotary_embedding(Tensor positions, Tensor! query," + " Tensor! key, int head_size," + " Tensor cos_sin_cache, bool is_neox," + " int rot_dim," + " Tensor cos_sin_cache_offsets) -> ()"); + ops.impl("batched_rotary_embedding", torch::kCUDA, &batched_rotary_embedding); + + // Quantization ops +#ifndef USE_ROCM + // Quantized GEMM for AQLM. + ops.def("aqlm_gemm", &aqlm_gemm); + ops.impl("aqlm_gemm", torch::kCUDA, &aqlm_gemm); + + // Decompression method for AQLM. + ops.def("aqlm_dequant", &aqlm_dequant); + ops.impl("aqlm_dequant", torch::kCUDA, &aqlm_dequant); + + // Quantized GEMM for AWQ. + ops.def("awq_gemm", &awq_gemm); + ops.impl("awq_gemm", torch::kCUDA, &awq_gemm); + + // Dequantization for AWQ. + ops.def("awq_dequantize", &awq_dequantize); + ops.impl("awq_dequantize", torch::kCUDA, &awq_dequantize); + + // Marlin (Dense) Optimized Quantized GEMM for GPTQ. + ops.def("marlin_gemm", &marlin_gemm); + ops.impl("marlin_gemm", torch::kCUDA, &marlin_gemm); + + // Marlin_24 (Sparse) Optimized Quantized GEMM for GPTQ. + ops.def("gptq_marlin_24_gemm", &gptq_marlin_24_gemm); + ops.impl("gptq_marlin_24_gemm", torch::kCUDA, &gptq_marlin_24_gemm); + + // gptq_marlin Optimized Quantized GEMM for GPTQ. + ops.def("gptq_marlin_gemm", &gptq_marlin_gemm); + ops.impl("gptq_marlin_gemm", torch::kCUDA, &gptq_marlin_gemm); + + // gptq_marlin repack from GPTQ. + ops.def("gptq_marlin_repack", &gptq_marlin_repack); + ops.impl("gptq_marlin_repack", torch::kCUDA, &gptq_marlin_repack); + + // CUTLASS w8a8 GEMM, supporting symmetric per-tensor or per-row/column + // quantization. + ops.def( + "cutlass_scaled_mm_dq(Tensor! out, Tensor a," + " Tensor b, Tensor a_scales," + " Tensor b_scales) -> ()"); + ops.impl("cutlass_scaled_mm_dq", torch::kCUDA, &cutlass_scaled_mm_dq); +#endif + + // Quantized GEMM for GPTQ. + ops.def("gptq_gemm", &gptq_gemm); + ops.impl("gptq_gemm", torch::kCUDA, &gptq_gemm); + + // Post processing for GPTQ. + ops.def("gptq_shuffle(Tensor! q_weight, Tensor q_perm, int bit) -> ()"); + ops.impl("gptq_shuffle", torch::kCUDA, &gptq_shuffle); + + // Quantized GEMM for SqueezeLLM. + ops.def( + "squeezellm_gemm(Tensor vec, Tensor mat, Tensor! mul, Tensor " + "lookup_table) -> ()"); + ops.impl("squeezellm_gemm", torch::kCUDA, &squeezellm_gemm); + + // Compute FP8 quantized tensor for given scaling factor. + ops.def( + "static_scaled_fp8_quant(Tensor! out, Tensor input, Tensor scale) -> ()"); + ops.impl("static_scaled_fp8_quant", torch::kCUDA, &static_scaled_fp8_quant); + + // Compute FP8 quantized tensor and scaling factor. + ops.def( + "dynamic_scaled_fp8_quant(Tensor! out, Tensor input, Tensor! scale) -> " + "()"); + ops.impl("dynamic_scaled_fp8_quant", torch::kCUDA, &dynamic_scaled_fp8_quant); + + // Aligning the number of tokens to be processed by each expert such + // that it is divisible by the block size. + ops.def( + "moe_align_block_size(Tensor topk_ids, int num_experts," + " int block_size, Tensor! sorted_token_ids," + " Tensor!
experts_ids," + " Tensor! num_tokens_post_pad) -> ()"); + ops.impl("moe_align_block_size", torch::kCUDA, &moe_align_block_size); + + // Compute int8 quantized tensor for given scaling factor. + ops.def( + "static_scaled_int8_quant(Tensor! out, Tensor input, Tensor scale) -> " + "()"); + ops.impl("static_scaled_int8_quant", torch::kCUDA, &static_scaled_int8_quant); + + // Compute int8 quantized tensor and scaling factor + ops.def( + "dynamic_scaled_int8_quant(Tensor! out, Tensor input, Tensor! scale) -> " + "()"); + ops.impl("dynamic_scaled_int8_quant", torch::kCUDA, + &dynamic_scaled_int8_quant); +} + +TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) { + // Cache ops + // Swap in (out) the cache blocks from src to dst. + cache_ops.def( + "swap_blocks(Tensor src, Tensor! dst, Tensor block_mapping) -> ()"); + cache_ops.impl("swap_blocks", torch::kCUDA, &swap_blocks); + + // Copy the cache blocks from src to dst. + cache_ops.def( + "copy_blocks(Tensor[]! key_caches, Tensor[]! value_caches, Tensor " + "block_mapping) -> ()"); + cache_ops.impl("copy_blocks", torch::kCUDA, &copy_blocks); + + // Reshape the key and value tensors and cache them. + cache_ops.def( + "reshape_and_cache(Tensor key, Tensor value," + " Tensor! key_cache, Tensor! value_cache," + " Tensor slot_mapping," + " str kv_cache_dtype," + " float kv_scale) -> ()"); + cache_ops.impl("reshape_and_cache", torch::kCUDA, &reshape_and_cache); + + // Reshape the key and value tensors and cache them. + cache_ops.def( + "reshape_and_cache_flash(Tensor key, Tensor value," + " Tensor! key_cache," + " Tensor! value_cache," + " Tensor slot_mapping," + " str kv_cache_dtype) -> ()"); + cache_ops.impl("reshape_and_cache_flash", torch::kCUDA, + &reshape_and_cache_flash); + + // Convert the key and value cache to fp8 data type. + cache_ops.def( + "convert_fp8(Tensor! dst_cache, Tensor src_cache, float scale, str " + "kv_cache_dtype) -> ()"); + cache_ops.impl("convert_fp8", torch::kCUDA, &convert_fp8); +} + +TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cuda_utils), cuda_utils) { + // Cuda utils + + // Gets the specified device attribute. + cuda_utils.def("get_device_attribute", &get_device_attribute); + cuda_utils.impl("get_device_attribute", torch::kCUDA, &get_device_attribute); + + // Gets the maximum shared memory per block device attribute. + cuda_utils.def("get_max_shared_memory_per_block_device_attribute", + &get_max_shared_memory_per_block_device_attribute); + cuda_utils.impl("get_max_shared_memory_per_block_device_attribute", + torch::kCUDA, + &get_max_shared_memory_per_block_device_attribute); +} + +#ifndef USE_ROCM +TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _custom_ar), custom_ar) { + // Custom all-reduce kernels + custom_ar.def("init_custom_ar", &init_custom_ar); + custom_ar.impl("init_custom_ar", torch::kCUDA, &init_custom_ar); + + custom_ar.def("should_custom_ar", &should_custom_ar); + custom_ar.impl("should_custom_ar", torch::kCUDA, &should_custom_ar); + + custom_ar.def("all_reduce_reg(int fa, Tensor inp, Tensor! out) -> ()"); + custom_ar.impl("all_reduce_reg", torch::kCUDA, &all_reduce_reg); + + custom_ar.def( + "all_reduce_unreg(int fa, Tensor inp, Tensor reg_buffer, Tensor!
out) -> " + "()"); + custom_ar.impl("all_reduce_unreg", torch::kCUDA, &all_reduce_unreg); + + custom_ar.def("dispose", &dispose); + custom_ar.impl("dispose", torch::kCPU, &dispose); + + custom_ar.def("meta_size", &meta_size); + custom_ar.impl("meta_size", torch::kCPU, &meta_size); + + custom_ar.def("register_buffer", &register_buffer); + custom_ar.impl("register_buffer", torch::kCUDA, &register_buffer); + + custom_ar.def("get_graph_buffer_ipc_meta", &get_graph_buffer_ipc_meta); + custom_ar.impl("get_graph_buffer_ipc_meta", torch::kCPU, + &get_graph_buffer_ipc_meta); + + custom_ar.def("register_graph_buffers", &register_graph_buffers); + custom_ar.impl("register_graph_buffers", torch::kCPU, + &register_graph_buffers); +} +#endif + +REGISTER_EXTENSION(TORCH_EXTENSION_NAME) diff --git a/docs/source/conf.py b/docs/source/conf.py index f1a7013edd332..ee0f6c53bd1b9 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -92,6 +92,7 @@ def setup(app): "vllm._C", "PIL", "numpy", + "triton", "tqdm", "tensorizer", ] diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst index 24fa83df7d751..5d3f55be1271f 100644 --- a/docs/source/models/supported_models.rst +++ b/docs/source/models/supported_models.rst @@ -89,7 +89,11 @@ Alongside each architecture, we include some popular models that use it. - ✅︎ * - :code:`LlavaForConditionalGeneration` - LLaVA-1.5 - - :code:`llava-hf/llava-1.5-7b-hf`\*, :code:`llava-hf/llava-1.5-13b-hf`\*, etc. + - :code:`llava-hf/llava-1.5-7b-hf`, :code:`llava-hf/llava-1.5-13b-hf`, etc. + - + * - :code:`LlavaNextForConditionalGeneration` + - LLaVA-NeXT + - :code:`llava-hf/llava-v1.6-mistral-7b-hf`, :code:`llava-hf/llava-v1.6-vicuna-7b-hf`, etc. + - * - :code:`MiniCPMForCausalLM` - MiniCPM diff --git a/docs/source/models/vlm.rst b/docs/source/models/vlm.rst index 52afda747aab8..33aa8246b2e60 100644 --- a/docs/source/models/vlm.rst +++ b/docs/source/models/vlm.rst @@ -3,7 +3,7 @@ Using VLMs ========== -This document shows you how to run and serve Vision Language Models (VLMs) using vLLM. +vLLM provides experimental support for Vision Language Models (VLMs). This document shows you how to run and serve these models using vLLM. Engine Arguments ---------------- @@ -16,6 +16,13 @@ The following :ref:`engine arguments <engine_args>` are specific to VLMs: :prog: -m vllm.entrypoints.openai.api_server :nodefaultconst: +.. important:: + Currently, the support for vision language models on vLLM has the following limitations: + + * Only single image input is supported per text prompt. + * Dynamic ``image_input_shape`` is not supported: the input image will be resized to the static ``image_input_shape``. This means model output might not exactly match the HuggingFace implementation. + We are continuously improving user & developer experience for VLMs. Please raise an issue on GitHub if you have any feedback or feature requests. + Offline Batched Inference ------------------------- @@ -31,7 +38,7 @@ To initialize a VLM, the aforementioned arguments must be passed to the ``LLM`` image_feature_size=576, ) -For now, we only support a single image per text prompt. To pass an image to the model, note the following in :class:`vllm.inputs.PromptStrictInputs`: +To pass an image to the model, note the following in :class:`vllm.inputs.PromptStrictInputs`: * ``prompt``: The prompt should have a number of ``<image>`` tokens equal to ``image_feature_size``.
* ``multi_modal_data``: This should be an instance of :class:`~vllm.multimodal.image.ImagePixelData` or :class:`~vllm.multimodal.image.ImageFeatureData`. @@ -54,3 +61,69 @@ For now, we only support a single image per text prompt. To pass an image to the print(generated_text) A code example can be found in `examples/llava_example.py <https://github.com/vllm-project/vllm/blob/main/examples/llava_example.py>`_. + +Online OpenAI Vision API Compatible Inference +---------------------------------------------- + +You can serve vision language models with vLLM's HTTP server that is compatible with the `OpenAI Vision API <https://platform.openai.com/docs/guides/vision>`_. + +.. note:: + Currently, vLLM supports only **single** ``image_url`` input per ``messages``. Support for multi-image inputs will be + added in the future. + +Below is an example of how to launch the same ``llava-hf/llava-1.5-7b-hf`` model with the vLLM API server. + +.. important:: + Since the OpenAI Vision API is based on the `Chat <https://platform.openai.com/docs/api-reference/chat>`_ API, a chat template + is **required** to launch the API server if the model's tokenizer does not come with one. In this example, we use the + HuggingFace Llava chat template that you can find in the example folder `here <https://github.com/vllm-project/vllm/blob/main/examples/template_llava.jinja>`_. + +.. code-block:: bash + + python -m vllm.entrypoints.openai.api_server \ + --model llava-hf/llava-1.5-7b-hf \ + --image-input-type pixel_values \ + --image-token-id 32000 \ + --image-input-shape 1,3,336,336 \ + --image-feature-size 576 \ + --chat-template template_llava.jinja + +To consume the server, you can use the OpenAI client like in the example below: + +.. code-block:: python + + from openai import OpenAI + openai_api_key = "EMPTY" + openai_api_base = "http://localhost:8000/v1" + client = OpenAI( + api_key=openai_api_key, + base_url=openai_api_base, + ) + chat_response = client.chat.completions.create( + model="llava-hf/llava-1.5-7b-hf", + messages=[{ + "role": "user", + "content": [ + {"type": "text", "text": "What's in this image?"}, + { + "type": "image_url", + "image_url": { + "url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg", + }, + }, + ], + }], + ) + print("Chat response:", chat_response) + +.. note:: + + By default, the timeout for fetching images through HTTP URLs is ``5`` seconds. You can override this by setting the environment variable: + + .. code-block:: shell + + export VLLM_IMAGE_FETCH_TIMEOUT=<timeout> + +.. note:: + The prompt formatting with the image token ``<image>`` is not needed when serving VLMs with the API server since the prompt will be + processed automatically by the server. diff --git a/docs/source/serving/openai_compatible_server.md b/docs/source/serving/openai_compatible_server.md index a912949352b86..6248d84683753 100644 --- a/docs/source/serving/openai_compatible_server.md +++ b/docs/source/serving/openai_compatible_server.md @@ -30,6 +30,8 @@ Please see the [OpenAI API Reference](https://platform.openai.com/docs/api-refer - Chat: `tools`, and `tool_choice`. - Completions: `suffix`. +vLLM also provides experimental support for OpenAI Vision API compatible inference. See more details in [Using VLMs](../models/vlm.rst). + ## Extra Parameters vLLM supports a set of parameters that are not part of the OpenAI API. In order to use them, you can pass them as extra parameters in the OpenAI client. @@ -120,4 +122,4 @@ It is the caller's responsibility to prompt the model with the tool information, vLLM will use guided decoding to ensure the response matches the tool parameter object defined by the JSON schema in the `tools` parameter.
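Returning to the Vision API support documented above: the client example there passes the image by URL, but a client can also inline a local image as a base64 data URL, the same pattern tests/entrypoints/test_openai_vision.py exercises. A minimal sketch, assuming the LLaVA server from that section is running on localhost:8000 and ``example.jpg`` is a placeholder path:

```python
import base64

from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

# Inline the image instead of having the server fetch it over HTTP.
with open("example.jpg", "rb") as f:
    image_b64 = base64.b64encode(f.read()).decode("utf-8")

chat_response = client.chat.completions.create(
    model="llava-hf/llava-1.5-7b-hf",
    messages=[{
        "role": "user",
        "content": [
            {"type": "text", "text": "What's in this image?"},
            {
                "type": "image_url",
                "image_url": {"url": f"data:image/jpeg;base64,{image_b64}"},
            },
        ],
    }],
)
print(chat_response.choices[0].message.content)
```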
-Please refer to the OpenAI API reference documentation for more information. \ No newline at end of file +Please refer to the OpenAI API reference documentation for more information. diff --git a/examples/template_llava.jinja b/examples/template_llava.jinja new file mode 100644 index 0000000000000..6a902ee167725 --- /dev/null +++ b/examples/template_llava.jinja @@ -0,0 +1,23 @@ +{%- if messages[0]['role'] == 'system' -%} + {%- set system_message = messages[0]['content'] -%} + {%- set messages = messages[1:] -%} +{%- else -%} + {% set system_message = '' -%} +{%- endif -%} + +{{ bos_token + system_message }} +{%- for message in messages -%} + {%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) -%} + {{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }} + {%- endif -%} + + {%- if message['role'] == 'user' -%} + {{ 'USER: ' + message['content'] + '\n' }} + {%- elif message['role'] == 'assistant' -%} + {{ 'ASSISTANT: ' + message['content'] + eos_token + '\n' }} + {%- endif -%} +{%- endfor -%} + +{%- if add_generation_prompt -%} + {{ 'ASSISTANT:' }} +{% endif %} diff --git a/setup.py b/setup.py index f7d465b60c153..339b0ad6de2d1 100644 --- a/setup.py +++ b/setup.py @@ -60,7 +60,7 @@ def remove_prefix(text, prefix): class CMakeExtension(Extension): def __init__(self, name: str, cmake_lists_dir: str = '.', **kwa) -> None: - super().__init__(name, sources=[], **kwa) + super().__init__(name, sources=[], py_limited_api=True, **kwa) self.cmake_lists_dir = os.path.abspath(cmake_lists_dir) diff --git a/tests/basic_correctness/test_basic_correctness.py b/tests/basic_correctness/test_basic_correctness.py index 7d8117447ca0a..805b8883b9d94 100644 --- a/tests/basic_correctness/test_basic_correctness.py +++ b/tests/basic_correctness/test_basic_correctness.py @@ -43,16 +43,14 @@ def test_models( if backend_by_env_var == "FLASHINFER" and enforce_eager is False: pytest.skip("Skipping non-eager test for FlashInferBackend.") - hf_model = hf_runner(model, dtype=dtype) - hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) - del hf_model - - vllm_model = vllm_runner(model, - dtype=dtype, - enforce_eager=enforce_eager, - gpu_memory_utilization=0.7) - vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) - del vllm_model + with hf_runner(model, dtype=dtype) as hf_model: + hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) + + with vllm_runner(model, + dtype=dtype, + enforce_eager=enforce_eager, + gpu_memory_utilization=0.7) as vllm_model: + vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) for i in range(len(example_prompts)): hf_output_ids, hf_output_str = hf_outputs[i] diff --git a/tests/basic_correctness/test_chunked_prefill.py b/tests/basic_correctness/test_chunked_prefill.py index 47d582c726c66..48d6091282b89 100644 --- a/tests/basic_correctness/test_chunked_prefill.py +++ b/tests/basic_correctness/test_chunked_prefill.py @@ -40,21 +40,19 @@ def test_models( enable_chunked_prefill = True max_num_batched_tokens = chunked_prefill_token_size - hf_model = hf_runner(model, dtype=dtype) - hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) - del hf_model - - vllm_model = vllm_runner( - model, - dtype=dtype, - max_num_batched_tokens=max_num_batched_tokens, - enable_chunked_prefill=enable_chunked_prefill, - tensor_parallel_size=tensor_parallel_size, - enforce_eager=enforce_eager, - max_num_seqs=max_num_seqs, - ) - vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) - del 
vllm_model + with hf_runner(model, dtype=dtype) as hf_model: + hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) + + with vllm_runner( + model, + dtype=dtype, + max_num_batched_tokens=max_num_batched_tokens, + enable_chunked_prefill=enable_chunked_prefill, + tensor_parallel_size=tensor_parallel_size, + enforce_eager=enforce_eager, + max_num_seqs=max_num_seqs, + ) as vllm_model: + vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) for i in range(len(example_prompts)): hf_output_ids, hf_output_str = hf_outputs[i] diff --git a/tests/basic_correctness/test_preemption.py b/tests/basic_correctness/test_preemption.py index 29a4c39cd25a1..7f20b2d934942 100644 --- a/tests/basic_correctness/test_preemption.py +++ b/tests/basic_correctness/test_preemption.py @@ -43,21 +43,19 @@ def test_chunked_prefill_recompute( enable_chunked_prefill = True max_num_batched_tokens = chunked_prefill_token_size - hf_model = hf_runner(model, dtype=dtype) - hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) - del hf_model + with hf_runner(model, dtype=dtype) as hf_model: + hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) - vllm_model = vllm_runner( - model, - dtype=dtype, - max_num_batched_tokens=max_num_batched_tokens, - enable_chunked_prefill=enable_chunked_prefill, - max_num_seqs=max_num_seqs, - ) - vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) - assert (vllm_model.model.llm_engine.scheduler.artificial_preempt_cnt < - ARTIFICIAL_PREEMPTION_MAX_CNT) - del vllm_model + with vllm_runner( + model, + dtype=dtype, + max_num_batched_tokens=max_num_batched_tokens, + enable_chunked_prefill=enable_chunked_prefill, + max_num_seqs=max_num_seqs, + ) as vllm_model: + vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) + assert (vllm_model.model.llm_engine.scheduler.artificial_preempt_cnt < + ARTIFICIAL_PREEMPTION_MAX_CNT) for i in range(len(example_prompts)): hf_output_ids, hf_output_str = hf_outputs[i] @@ -82,21 +80,19 @@ def test_preemption( ) -> None: """By default, recompute preemption is enabled""" - hf_model = hf_runner(model, dtype=dtype) - hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) - del hf_model + with hf_runner(model, dtype=dtype) as hf_model: + hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) - vllm_model = vllm_runner( - model, - dtype=dtype, - disable_log_stats=False, - ) - vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) - assert (vllm_model.model.llm_engine.scheduler.artificial_preempt_cnt < - ARTIFICIAL_PREEMPTION_MAX_CNT) - total_preemption = ( - vllm_model.model.llm_engine.scheduler.num_cumulative_preemption) - del vllm_model + with vllm_runner( + model, + dtype=dtype, + disable_log_stats=False, + ) as vllm_model: + vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) + assert (vllm_model.model.llm_engine.scheduler.artificial_preempt_cnt < + ARTIFICIAL_PREEMPTION_MAX_CNT) + total_preemption = ( + vllm_model.model.llm_engine.scheduler.num_cumulative_preemption) for i in range(len(example_prompts)): hf_output_ids, hf_output_str = hf_outputs[i] @@ -137,24 +133,22 @@ def test_swap( ) -> None: """Use beam search enables swapping.""" example_prompts = example_prompts[:1] - hf_model = hf_runner(model, dtype=dtype) - hf_outputs = hf_model.generate_beam_search(example_prompts, beam_width, - max_tokens) - del hf_model - - vllm_model = vllm_runner( - model, - dtype=dtype, - swap_space=10, - disable_log_stats=False, - ) - 
vllm_outputs = vllm_model.generate_beam_search(example_prompts, beam_width, + with hf_runner(model, dtype=dtype) as hf_model: + hf_outputs = hf_model.generate_beam_search(example_prompts, beam_width, max_tokens) - assert (vllm_model.model.llm_engine.scheduler.artificial_preempt_cnt < - ARTIFICIAL_PREEMPTION_MAX_CNT) - total_preemption = ( - vllm_model.model.llm_engine.scheduler.num_cumulative_preemption) - del vllm_model + + with vllm_runner( + model, + dtype=dtype, + swap_space=10, + disable_log_stats=False, + ) as vllm_model: + vllm_outputs = vllm_model.generate_beam_search(example_prompts, + beam_width, max_tokens) + assert (vllm_model.model.llm_engine.scheduler.artificial_preempt_cnt < + ARTIFICIAL_PREEMPTION_MAX_CNT) + total_preemption = ( + vllm_model.model.llm_engine.scheduler.num_cumulative_preemption) for i in range(len(example_prompts)): hf_output_ids, _ = hf_outputs[i] @@ -199,28 +193,28 @@ def test_swap_infeasible( decode_blocks = max_tokens // BLOCK_SIZE example_prompts = example_prompts[:1] - vllm_model = vllm_runner( - model, - dtype=dtype, - swap_space=10, - block_size=BLOCK_SIZE, - # Since beam search have more than 1 sequence, prefill + decode blocks - # are not enough to finish. - num_gpu_blocks_override=prefill_blocks + decode_blocks, - max_model_len=(prefill_blocks + decode_blocks) * BLOCK_SIZE, - ) - sampling_params = SamplingParams(n=beam_width, - use_beam_search=True, - temperature=0.0, - max_tokens=max_tokens, - ignore_eos=True) - req_outputs = vllm_model.model.generate( - example_prompts, - sampling_params=sampling_params, - ) - assert (vllm_model.model.llm_engine.scheduler.artificial_preempt_cnt < - ARTIFICIAL_PREEMPTION_MAX_CNT) - del vllm_model + with vllm_runner( + model, + dtype=dtype, + swap_space=10, + block_size=BLOCK_SIZE, + # Since beam search have more than 1 sequence, prefill + + # decode blocks are not enough to finish. + num_gpu_blocks_override=prefill_blocks + decode_blocks, + max_model_len=(prefill_blocks + decode_blocks) * BLOCK_SIZE, + ) as vllm_model: + sampling_params = SamplingParams(n=beam_width, + use_beam_search=True, + temperature=0.0, + max_tokens=max_tokens, + ignore_eos=True) + req_outputs = vllm_model.model.generate( + example_prompts, + sampling_params=sampling_params, + ) + assert (vllm_model.model.llm_engine.scheduler.artificial_preempt_cnt < + ARTIFICIAL_PREEMPTION_MAX_CNT) + # Verify the request is ignored and not hang. assert req_outputs[0].outputs[0].finish_reason == "length" @@ -239,25 +233,26 @@ def test_preemption_infeasible( BLOCK_SIZE = 16 prefill_blocks = 2 decode_blocks = max_tokens // BLOCK_SIZE - vllm_model = vllm_runner( - model, - dtype=dtype, - block_size=BLOCK_SIZE, - # Not enough gpu blocks to complete a single sequence. - # preemption should happen, and the sequence should be - # ignored instead of hanging forever. - num_gpu_blocks_override=prefill_blocks + decode_blocks // 2, - max_model_len=((prefill_blocks + decode_blocks // 2) * BLOCK_SIZE), - ) - sampling_params = SamplingParams(max_tokens=max_tokens, ignore_eos=True) - req_outputs = vllm_model.model.generate( - example_prompts, - sampling_params=sampling_params, - ) + with vllm_runner( + model, + dtype=dtype, + block_size=BLOCK_SIZE, + # Not enough gpu blocks to complete a single sequence. + # preemption should happen, and the sequence should be + # ignored instead of hanging forever. 
+ num_gpu_blocks_override=prefill_blocks + decode_blocks // 2, + max_model_len=((prefill_blocks + decode_blocks // 2) * BLOCK_SIZE), + ) as vllm_model: + sampling_params = SamplingParams(max_tokens=max_tokens, + ignore_eos=True) + req_outputs = vllm_model.model.generate( + example_prompts, + sampling_params=sampling_params, + ) + + assert (vllm_model.model.llm_engine.scheduler.artificial_preempt_cnt < + ARTIFICIAL_PREEMPTION_MAX_CNT) - assert (vllm_model.model.llm_engine.scheduler.artificial_preempt_cnt < - ARTIFICIAL_PREEMPTION_MAX_CNT) - del vllm_model # Verify the request is ignored and not hang. for req_output in req_outputs: outputs = req_output.outputs diff --git a/tests/conftest.py b/tests/conftest.py index 1a7037eb2f290..e0680467d78b9 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -354,7 +354,10 @@ def generate_greedy_logprobs_limit( def encode(self, prompts: List[str]) -> List[List[torch.Tensor]]: return self.model.encode(prompts) - def __del__(self): + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): del self.model cleanup() @@ -490,7 +493,10 @@ def encode(self, prompts: List[str]) -> List[List[float]]: outputs.append(embedding) return outputs - def __del__(self): + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): del self.model cleanup() diff --git a/tests/distributed/test_basic_distributed_correctness.py b/tests/distributed/test_basic_distributed_correctness.py index 3ba5cea389c2f..eb423aef230cb 100644 --- a/tests/distributed/test_basic_distributed_correctness.py +++ b/tests/distributed/test_basic_distributed_correctness.py @@ -42,18 +42,16 @@ def test_models( backend_by_env_var = os.getenv(VLLM_ATTENTION_BACKEND) enforce_eager = backend_by_env_var == "FLASHINFER" - hf_model = hf_runner(model, dtype=dtype) - hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) - del hf_model - - vllm_model = vllm_runner( - model, - dtype=dtype, - tensor_parallel_size=2, - enforce_eager=enforce_eager, - distributed_executor_backend=distributed_executor_backend) - vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) - del vllm_model + with hf_runner(model, dtype=dtype) as hf_model: + hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) + + with vllm_runner(model, + dtype=dtype, + tensor_parallel_size=2, + enforce_eager=enforce_eager, + distributed_executor_backend=distributed_executor_backend + ) as vllm_model: + vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) for i in range(len(example_prompts)): hf_output_ids, hf_output_str = hf_outputs[i] diff --git a/tests/distributed/test_chunked_prefill_distributed.py b/tests/distributed/test_chunked_prefill_distributed.py index db938cc613c6b..4e4e468c4377a 100644 --- a/tests/distributed/test_chunked_prefill_distributed.py +++ b/tests/distributed/test_chunked_prefill_distributed.py @@ -45,21 +45,19 @@ def test_models( enable_chunked_prefill = True max_num_batched_tokens = chunked_prefill_token_size - hf_model = hf_runner(model, dtype=dtype) - hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) - del hf_model + with hf_runner(model, dtype=dtype) as hf_model: + hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) - vllm_model = vllm_runner( - model, - dtype=dtype, - tensor_parallel_size=2, - max_num_seqs=max_num_seqs, - enable_chunked_prefill=enable_chunked_prefill, - max_num_batched_tokens=max_num_batched_tokens, - 
distributed_executor_backend=distributed_executor_backend, - ) - vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) - del vllm_model + with vllm_runner( + model, + dtype=dtype, + tensor_parallel_size=2, + max_num_seqs=max_num_seqs, + enable_chunked_prefill=enable_chunked_prefill, + max_num_batched_tokens=max_num_batched_tokens, + distributed_executor_backend=distributed_executor_backend, + ) as vllm_model: + vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) for i in range(len(example_prompts)): hf_output_ids, hf_output_str = hf_outputs[i] diff --git a/tests/engine/test_stop_reason.py b/tests/engine/test_stop_reason.py index 7b886507c04f2..b0bd6c4aa95d3 100644 --- a/tests/engine/test_stop_reason.py +++ b/tests/engine/test_stop_reason.py @@ -19,9 +19,8 @@ @pytest.fixture def vllm_model(vllm_runner): - vllm_model = vllm_runner(MODEL) - yield vllm_model - del vllm_model + with vllm_runner(MODEL) as vllm_model: + yield vllm_model def test_stop_reason(vllm_model, example_prompts): diff --git a/tests/engine/test_stop_strings.py b/tests/engine/test_stop_strings.py index 6b747beb4b543..1584b85aeb064 100644 --- a/tests/engine/test_stop_strings.py +++ b/tests/engine/test_stop_strings.py @@ -10,7 +10,8 @@ @pytest.fixture(scope="session") def vllm_model(vllm_runner): - return vllm_runner(MODEL) + with vllm_runner(MODEL) as vllm_model: + yield vllm_model @pytest.mark.skip_global_cleanup diff --git a/tests/entrypoints/test_openai_server.py b/tests/entrypoints/test_openai_server.py index b7d0946ba7244..d0fe08ae0ddd2 100644 --- a/tests/entrypoints/test_openai_server.py +++ b/tests/entrypoints/test_openai_server.py @@ -478,8 +478,6 @@ async def test_completion_streaming(server, client: openai.AsyncOpenAI, temperature=0.0, ) single_output = single_completion.choices[0].text - single_usage = single_completion.usage - stream = await client.completions.create(model=model_name, prompt=prompt, max_tokens=5, @@ -495,7 +493,6 @@ async def test_completion_streaming(server, client: openai.AsyncOpenAI, assert finish_reason_count == 1 assert chunk.choices[0].finish_reason == "length" assert chunk.choices[0].text - assert chunk.usage == single_usage assert "".join(chunks) == single_output @@ -550,6 +547,138 @@ async def test_chat_streaming(server, client: openai.AsyncOpenAI, assert "".join(chunks) == output +@pytest.mark.asyncio +@pytest.mark.parametrize( + "model_name", + ["HuggingFaceH4/zephyr-7b-beta", "zephyr-lora"], +) +async def test_chat_completion_stream_options(server, + client: openai.AsyncOpenAI, + model_name: str): + messages = [{ + "role": "system", + "content": "You are a helpful assistant." + }, { + "role": "user", + "content": "What is the capital of France?" 
+ }] + + # Test stream=True, stream_options={"include_usage": False} + stream = await client.chat.completions.create( + model=model_name, + messages=messages, + max_tokens=10, + temperature=0.0, + stream=True, + stream_options={"include_usage": False}) + async for chunk in stream: + assert chunk.usage is None + + # Test stream=True, stream_options={"include_usage": True} + stream = await client.chat.completions.create( + model=model_name, + messages=messages, + max_tokens=10, + temperature=0.0, + stream=True, + stream_options={"include_usage": True}) + + async for chunk in stream: + if chunk.choices[0].finish_reason is None: + assert chunk.usage is None + else: + assert chunk.usage is None + final_chunk = await stream.__anext__() + assert final_chunk.usage is not None + assert final_chunk.usage.prompt_tokens > 0 + assert final_chunk.usage.completion_tokens > 0 + assert final_chunk.usage.total_tokens == ( + final_chunk.usage.prompt_tokens + + final_chunk.usage.completion_tokens) + assert final_chunk.choices == [] + + # Test stream=False, stream_options={"include_usage": None} + with pytest.raises(BadRequestError): + await client.chat.completions.create( + model=model_name, + messages=messages, + max_tokens=10, + temperature=0.0, + stream=False, + stream_options={"include_usage": None}) + + # Test stream=False, stream_options={"include_usage": True} + with pytest.raises(BadRequestError): + await client.chat.completions.create( + model=model_name, + messages=messages, + max_tokens=10, + temperature=0.0, + stream=False, + stream_options={"include_usage": True}) + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "model_name", + ["HuggingFaceH4/zephyr-7b-beta", "zephyr-lora"], +) +async def test_completion_stream_options(server, client: openai.AsyncOpenAI, + model_name: str): + prompt = "What is the capital of France?" 
+ + # Test stream=True, stream_options={"include_usage": False} + stream = await client.completions.create( + model=model_name, + prompt=prompt, + max_tokens=5, + temperature=0.0, + stream=True, + stream_options={"include_usage": False}) + async for chunk in stream: + assert chunk.usage is None + + # Test stream=True, stream_options={"include_usage": True} + stream = await client.completions.create( + model=model_name, + prompt=prompt, + max_tokens=5, + temperature=0.0, + stream=True, + stream_options={"include_usage": True}) + async for chunk in stream: + if chunk.choices[0].finish_reason is None: + assert chunk.usage is None + else: + assert chunk.usage is None + final_chunk = await stream.__anext__() + assert final_chunk.usage is not None + assert final_chunk.usage.prompt_tokens > 0 + assert final_chunk.usage.completion_tokens > 0 + assert final_chunk.usage.total_tokens == ( + final_chunk.usage.prompt_tokens + + final_chunk.usage.completion_tokens) + assert final_chunk.choices == [] + + # Test stream=False, stream_options={"include_usage": None} + with pytest.raises(BadRequestError): + await client.completions.create(model=model_name, + prompt=prompt, + max_tokens=5, + temperature=0.0, + stream=False, + stream_options={"include_usage": None}) + + # Test stream=False, stream_options={"include_usage": True} + with pytest.raises(BadRequestError): + await client.completions.create(model=model_name, + prompt=prompt, + max_tokens=5, + temperature=0.0, + stream=False, + stream_options={"include_usage": True}) + + @pytest.mark.asyncio @pytest.mark.parametrize( # just test 1 lora hereafter @@ -1343,106 +1472,5 @@ async def test_batch_embedding(embedding_server, client: openai.AsyncOpenAI, assert embeddings.usage.total_tokens == 17 -@pytest.mark.parametrize( - "model_name", - [MODEL_NAME], -) -async def test_stream_options(server, client: openai.AsyncOpenAI, - model_name: str): - prompt = "What is the capital of France?" 
- - # Test stream=True, stream_options=None - stream = await client.completions.create( - model=model_name, - prompt=prompt, - max_tokens=5, - temperature=0.0, - stream=True, - stream_options=None, - ) - chunks = [] - async for chunk in stream: - chunks.append(chunk.choices[0].text) - assert len(chunks) > 0 - assert "usage" not in chunk - - # Test stream=True, stream_options={"include_usage": False} - stream = await client.completions.create( - model=model_name, - prompt=prompt, - max_tokens=5, - temperature=0.0, - stream=True, - stream_options={"include_usage": False}, - ) - chunks = [] - async for chunk in stream: - chunks.append(chunk.choices[0].text) - assert len(chunks) > 0 - assert "usage" not in chunk - - # Test stream=True, stream_options={"include_usage": True} - stream = await client.completions.create( - model=model_name, - prompt=prompt, - max_tokens=5, - temperature=0.0, - stream=True, - stream_options={"include_usage": True}, - ) - chunks = [] - finish_reason_count = 0 - async for chunk in stream: - if chunk.choices[0].finish_reason is None: - assert chunk.usage is None - chunks.append(chunk.choices[0].text) - else: - assert chunk.usage is None - finish_reason_count += 1 - - # The last message should have usage and no choices - last_message = await stream.__anext__() - assert last_message.usage is not None - assert last_message.usage.prompt_tokens > 0 - assert last_message.usage.completion_tokens > 0 - assert last_message.usage.total_tokens == ( - last_message.usage.prompt_tokens + - last_message.usage.completion_tokens) - assert last_message.choices == [] - - # Test stream=False, stream_options={"include_usage": None} - with pytest.raises(BadRequestError): - await client.completions.create( - model=model_name, - prompt=prompt, - max_tokens=5, - temperature=0.0, - stream=False, - stream_options={"include_usage": None}, - ) - - # Test stream=False, stream_options={"include_usage": False} - with pytest.raises(BadRequestError): - await client.completions.create( - model=model_name, - prompt=prompt, - max_tokens=5, - temperature=0.0, - stream=False, - stream_options={"include_usage": False}, - ) - - # Test stream=False, stream_options={"include_usage": True} - with pytest.raises(BadRequestError): - await client.completions.create( - model=model_name, - prompt=prompt, - max_tokens=5, - temperature=0.0, - stream=False, - stream_options={"include_usage": True}, - ) - - if __name__ == "__main__": pytest.main([__file__]) diff --git a/tests/entrypoints/test_openai_vision.py b/tests/entrypoints/test_openai_vision.py new file mode 100644 index 0000000000000..cc03b04e0b0e0 --- /dev/null +++ b/tests/entrypoints/test_openai_vision.py @@ -0,0 +1,286 @@ +from pathlib import Path +from typing import Dict + +import openai +import pytest +import pytest_asyncio +import ray + +from vllm.multimodal.utils import ImageFetchAiohttp, encode_image_base64 + +from ..utils import ServerRunner + +MODEL_NAME = "llava-hf/llava-1.5-7b-hf" +LLAVA_CHAT_TEMPLATE = (Path(__file__).parent.parent.parent / + "examples/template_llava.jinja") +assert LLAVA_CHAT_TEMPLATE.exists() +# Test different image extensions (JPG/PNG) and formats (gray/RGB/RGBA) +TEST_IMAGE_URLS = [ + "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg", + "https://upload.wikimedia.org/wikipedia/commons/f/fa/Grayscale_8bits_palette_sample_image.png", + 
"https://upload.wikimedia.org/wikipedia/commons/thumb/9/91/Venn_diagram_rgb.svg/1280px-Venn_diagram_rgb.svg.png", + "https://upload.wikimedia.org/wikipedia/commons/0/0b/RGBA_comp.png", +] + +pytestmark = pytest.mark.openai + + +@pytest.fixture(scope="module") +def server(): + ray.init() + server_runner = ServerRunner.remote([ + "--model", + MODEL_NAME, + "--dtype", + "bfloat16", + "--max-model-len", + "4096", + "--enforce-eager", + "--image-input-type", + "pixel_values", + "--image-token-id", + "32000", + "--image-input-shape", + "1,3,336,336", + "--image-feature-size", + "576", + "--chat-template", + str(LLAVA_CHAT_TEMPLATE), + ]) + ray.get(server_runner.ready.remote()) + yield server_runner + ray.shutdown() + + +@pytest.fixture(scope="session") +def client(): + client = openai.AsyncOpenAI( + base_url="http://localhost:8000/v1", + api_key="token-abc123", + ) + yield client + + +@pytest_asyncio.fixture(scope="session") +async def base64_encoded_image() -> Dict[str, str]: + return { + image_url: + encode_image_base64(await ImageFetchAiohttp.fetch_image(image_url)) + for image_url in TEST_IMAGE_URLS + } + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +@pytest.mark.parametrize("image_url", TEST_IMAGE_URLS) +async def test_single_chat_session_image(server, client: openai.AsyncOpenAI, + model_name: str, image_url: str): + messages = [{ + "role": + "user", + "content": [ + { + "type": "image_url", + "image_url": { + "url": image_url + } + }, + { + "type": "text", + "text": "What's in this image?" + }, + ], + }] + + # test single completion + chat_completion = await client.chat.completions.create(model=model_name, + messages=messages, + max_tokens=10, + logprobs=True, + top_logprobs=5) + assert len(chat_completion.choices) == 1 + + choice = chat_completion.choices[0] + assert choice.finish_reason == "length" + assert chat_completion.usage == openai.types.CompletionUsage( + completion_tokens=10, prompt_tokens=596, total_tokens=606) + + message = choice.message + message = chat_completion.choices[0].message + assert message.content is not None and len(message.content) >= 10 + assert message.role == "assistant" + messages.append({"role": "assistant", "content": message.content}) + + # test multi-turn dialogue + messages.append({"role": "user", "content": "express your result in json"}) + chat_completion = await client.chat.completions.create( + model=model_name, + messages=messages, + max_tokens=10, + ) + message = chat_completion.choices[0].message + assert message.content is not None and len(message.content) >= 0 + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +@pytest.mark.parametrize("image_url", TEST_IMAGE_URLS) +async def test_single_chat_session_image_base64encoded( + server, client: openai.AsyncOpenAI, model_name: str, image_url: str, + base64_encoded_image: Dict[str, str]): + + messages = [{ + "role": + "user", + "content": [ + { + "type": "image_url", + "image_url": { + "url": + f"data:image/jpeg;base64,{base64_encoded_image[image_url]}" + } + }, + { + "type": "text", + "text": "What's in this image?" 
+ }, + ], + }] + + # test single completion + chat_completion = await client.chat.completions.create(model=model_name, + messages=messages, + max_tokens=10, + logprobs=True, + top_logprobs=5) + assert len(chat_completion.choices) == 1 + + choice = chat_completion.choices[0] + assert choice.finish_reason == "length" + assert chat_completion.usage == openai.types.CompletionUsage( + completion_tokens=10, prompt_tokens=596, total_tokens=606) + + message = choice.message + message = chat_completion.choices[0].message + assert message.content is not None and len(message.content) >= 10 + assert message.role == "assistant" + messages.append({"role": "assistant", "content": message.content}) + + # test multi-turn dialogue + messages.append({"role": "user", "content": "express your result in json"}) + chat_completion = await client.chat.completions.create( + model=model_name, + messages=messages, + max_tokens=10, + ) + message = chat_completion.choices[0].message + assert message.content is not None and len(message.content) >= 0 + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +@pytest.mark.parametrize("image_url", TEST_IMAGE_URLS) +async def test_chat_streaming_image(server, client: openai.AsyncOpenAI, + model_name: str, image_url: str): + messages = [{ + "role": + "user", + "content": [ + { + "type": "image_url", + "image_url": { + "url": image_url + } + }, + { + "type": "text", + "text": "What's in this image?" + }, + ], + }] + + # test single completion + chat_completion = await client.chat.completions.create( + model=model_name, + messages=messages, + max_tokens=10, + temperature=0.0, + ) + output = chat_completion.choices[0].message.content + stop_reason = chat_completion.choices[0].finish_reason + + # test streaming + stream = await client.chat.completions.create( + model=model_name, + messages=messages, + max_tokens=10, + temperature=0.0, + stream=True, + ) + chunks = [] + finish_reason_count = 0 + async for chunk in stream: + delta = chunk.choices[0].delta + if delta.role: + assert delta.role == "assistant" + if delta.content: + chunks.append(delta.content) + if chunk.choices[0].finish_reason is not None: + finish_reason_count += 1 + # finish reason should only return in last block + assert finish_reason_count == 1 + assert chunk.choices[0].finish_reason == stop_reason + assert delta.content + assert "".join(chunks) == output + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +@pytest.mark.parametrize("image_url", TEST_IMAGE_URLS) +async def test_multi_image_input(server, client: openai.AsyncOpenAI, + model_name: str, image_url: str): + + messages = [{ + "role": + "user", + "content": [ + { + "type": "image_url", + "image_url": { + "url": image_url + } + }, + { + "type": "image_url", + "image_url": { + "url": image_url + } + }, + { + "type": "text", + "text": "What's in this image?" 
+ }, + ], + }] + + with pytest.raises(openai.BadRequestError): # test multi-image input + await client.chat.completions.create( + model=model_name, + messages=messages, + max_tokens=10, + temperature=0.0, + ) + + # the server should still work afterwards + completion = await client.completions.create( + model=model_name, + prompt=[0, 0, 0, 0, 0], + max_tokens=5, + temperature=0.0, + ) + completion = completion.choices[0].text + assert completion is not None and len(completion) >= 0 + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/tests/kernels/test_int8_quant.py b/tests/kernels/test_int8_quant.py index aab7af9d2cbf6..0daf7439468aa 100644 --- a/tests/kernels/test_int8_quant.py +++ b/tests/kernels/test_int8_quant.py @@ -1,7 +1,8 @@ import pytest import torch -from vllm._C import ops +# ruff: noqa: F401 +import vllm._C DTYPES = [torch.half, torch.bfloat16, torch.float] HIDDEN_SIZES = [16, 67, 768, 2048, 5120, 5137, 8192, @@ -33,7 +34,7 @@ def test_dynamic_scaled_int8_quant(num_tokens: int, hidden_size: int, ops_out = torch.empty_like(x, dtype=torch.int8, device="cuda") scales_out = torch.empty_like(scales, dtype=torch.float32, device="cuda") - ops.dynamic_scaled_int8_quant(ops_out, x, scales_out) + torch.ops._C.dynamic_scaled_int8_quant(ops_out, x, scales_out) assert torch.allclose(scales_out, scales) assert torch.allclose(torch_out, ops_out, @@ -60,6 +61,6 @@ def test_static_scaled_int8_quant(num_tokens: int, hidden_size: int, out2 = torch.empty_like(x, dtype=torch.int8) scale_argument = torch.tensor([scale], dtype=torch.float32, device="cuda") - ops.static_scaled_int8_quant(out2, x, scale_argument) + torch.ops._C.static_scaled_int8_quant(out2, x, scale_argument) assert torch.allclose(out1, out2, atol=1) # big atol to account for rounding errors diff --git a/tests/lora/test_utils.py b/tests/lora/test_utils.py index 892f6081e2aaa..4ff9715b4ca8d 100644 --- a/tests/lora/test_utils.py +++ b/tests/lora/test_utils.py @@ -1,12 +1,13 @@ from collections import OrderedDict +import pytest from torch import nn from vllm.lora.utils import parse_fine_tuned_lora_name, replace_submodule from vllm.utils import LRUCache -def test_parse_fine_tuned_lora_name(): +def test_parse_fine_tuned_lora_name_valid(): fixture = { ("base_model.model.lm_head.lora_A.weight", "lm_head", True), ("base_model.model.lm_head.lora_B.weight", "lm_head", False), @@ -35,6 +36,17 @@ def test_parse_fine_tuned_lora_name(): assert (module_name, is_lora_a) == parse_fine_tuned_lora_name(name) +def test_parse_fine_tuned_lora_name_invalid(): + fixture = { + "weight", + "base_model.weight", + "base_model.model.weight", + } + for name in fixture: + with pytest.raises(ValueError, match="unsupported LoRA weight"): + parse_fine_tuned_lora_name(name) + + def test_replace_submodule(): model = nn.Sequential( OrderedDict([ diff --git a/tests/metrics/test_metrics.py b/tests/metrics/test_metrics.py index e0aa14f165c2d..c1164739eee31 100644 --- a/tests/metrics/test_metrics.py +++ b/tests/metrics/test_metrics.py @@ -23,23 +23,25 @@ def test_metric_counter_prompt_tokens( dtype: str, max_tokens: int, ) -> None: - vllm_model = vllm_runner(model, - dtype=dtype, - disable_log_stats=False, - gpu_memory_utilization=0.4) - tokenizer = vllm_model.model.get_tokenizer() - prompt_token_counts = [len(tokenizer.encode(p)) for p in example_prompts] - # This test needs at least 2 prompts in a batch of different lengths to - # verify their token count is correct despite padding. 
- assert len(example_prompts) > 1, "at least 2 prompts are required" - assert prompt_token_counts[0] != prompt_token_counts[1], ( - "prompts of different lengths are required") - vllm_prompt_token_count = sum(prompt_token_counts) - - _ = vllm_model.generate_greedy(example_prompts, max_tokens) - stat_logger = vllm_model.model.llm_engine.stat_logger - metric_count = stat_logger.metrics.counter_prompt_tokens.labels( - **stat_logger.labels)._value.get() + with vllm_runner(model, + dtype=dtype, + disable_log_stats=False, + gpu_memory_utilization=0.4) as vllm_model: + tokenizer = vllm_model.model.get_tokenizer() + prompt_token_counts = [ + len(tokenizer.encode(p)) for p in example_prompts + ] + # This test needs at least 2 prompts in a batch of different lengths to + # verify their token count is correct despite padding. + assert len(example_prompts) > 1, "at least 2 prompts are required" + assert prompt_token_counts[0] != prompt_token_counts[1], ( + "prompts of different lengths are required") + vllm_prompt_token_count = sum(prompt_token_counts) + + _ = vllm_model.generate_greedy(example_prompts, max_tokens) + stat_logger = vllm_model.model.llm_engine.stat_logger + metric_count = stat_logger.metrics.counter_prompt_tokens.labels( + **stat_logger.labels)._value.get() assert vllm_prompt_token_count == metric_count, ( f"prompt token count: {vllm_prompt_token_count!r}\n" @@ -56,22 +58,22 @@ def test_metric_counter_generation_tokens( dtype: str, max_tokens: int, ) -> None: - vllm_model = vllm_runner(model, - dtype=dtype, - disable_log_stats=False, - gpu_memory_utilization=0.4) - vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) - tokenizer = vllm_model.model.get_tokenizer() - stat_logger = vllm_model.model.llm_engine.stat_logger - metric_count = stat_logger.metrics.counter_generation_tokens.labels( - **stat_logger.labels)._value.get() - vllm_generation_count = 0 - for i in range(len(example_prompts)): - vllm_output_ids, vllm_output_str = vllm_outputs[i] - prompt_ids = tokenizer.encode(example_prompts[i]) - # vllm_output_ids contains both prompt tokens and generation tokens. - # We're interested only in the count of the generation tokens. - vllm_generation_count += len(vllm_output_ids) - len(prompt_ids) + with vllm_runner(model, + dtype=dtype, + disable_log_stats=False, + gpu_memory_utilization=0.4) as vllm_model: + vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) + tokenizer = vllm_model.model.get_tokenizer() + stat_logger = vllm_model.model.llm_engine.stat_logger + metric_count = stat_logger.metrics.counter_generation_tokens.labels( + **stat_logger.labels)._value.get() + vllm_generation_count = 0 + for i in range(len(example_prompts)): + vllm_output_ids, vllm_output_str = vllm_outputs[i] + prompt_ids = tokenizer.encode(example_prompts[i]) + # vllm_output_ids contains both prompt tokens and generation tokens. + # We're interested only in the count of the generation tokens. 
+ vllm_generation_count += len(vllm_output_ids) - len(prompt_ids) assert vllm_generation_count == metric_count, ( f"generation token count: {vllm_generation_count!r}\n" @@ -85,15 +87,13 @@ def test_metric_counter_generation_tokens( [None, [], ["ModelName0"], ["ModelName0", "ModelName1", "ModelName2"]]) def test_metric_set_tag_model_name(vllm_runner, model: str, dtype: str, served_model_name: List[str]) -> None: - vllm_model = vllm_runner(model, - dtype=dtype, - disable_log_stats=False, - gpu_memory_utilization=0.3, - served_model_name=served_model_name) - stat_logger = vllm_model.model.llm_engine.stat_logger - metrics_tag_content = stat_logger.labels["model_name"] - - del vllm_model + with vllm_runner(model, + dtype=dtype, + disable_log_stats=False, + gpu_memory_utilization=0.3, + served_model_name=served_model_name) as vllm_model: + stat_logger = vllm_model.model.llm_engine.stat_logger + metrics_tag_content = stat_logger.labels["model_name"] if served_model_name is None or served_model_name == []: assert metrics_tag_content == model, ( diff --git a/tests/models/test_aqlm.py b/tests/models/test_aqlm.py index 85d74f7f5b03d..c4ecf846e633c 100644 --- a/tests/models/test_aqlm.py +++ b/tests/models/test_aqlm.py @@ -82,10 +82,9 @@ def test_models( num_logprobs: int, ) -> None: - vllm_model = vllm_runner(model, dtype=dtype) - vllm_outputs = vllm_model.generate_greedy_logprobs(example_prompts, - max_tokens, - num_logprobs) + with vllm_runner(model, dtype=dtype) as vllm_model: + vllm_outputs = vllm_model.generate_greedy_logprobs( + example_prompts, max_tokens, num_logprobs) # loop through the prompts to compare against the ground truth generations for prompt_idx in range(len(example_prompts)): diff --git a/tests/models/test_big_models.py b/tests/models/test_big_models.py index ea95e6a49f03a..ef78283731775 100644 --- a/tests/models/test_big_models.py +++ b/tests/models/test_big_models.py @@ -34,13 +34,11 @@ def test_models( dtype: str, max_tokens: int, ) -> None: - hf_model = hf_runner(model, dtype=dtype) - hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) - del hf_model + with hf_runner(model, dtype=dtype) as hf_model: + hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) - vllm_model = vllm_runner(model, dtype=dtype) - vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) - del vllm_model + with vllm_runner(model, dtype=dtype) as vllm_model: + vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) for i in range(len(example_prompts)): hf_output_ids, hf_output_str = hf_outputs[i] @@ -58,9 +56,8 @@ def test_model_print( model: str, dtype: str, ) -> None: - vllm_model = vllm_runner(model, dtype=dtype) - # This test is for verifying whether the model's extra_repr - # can be printed correctly. - print(vllm_model.model.llm_engine.model_executor.driver_worker. - model_runner.model) - del vllm_model + with vllm_runner(model, dtype=dtype) as vllm_model: + # This test is for verifying whether the model's extra_repr + # can be printed correctly. + print(vllm_model.model.llm_engine.model_executor.driver_worker. 
+ model_runner.model) diff --git a/tests/models/test_embedding.py b/tests/models/test_embedding.py index 668ed3a520a36..6556998b68a74 100644 --- a/tests/models/test_embedding.py +++ b/tests/models/test_embedding.py @@ -28,13 +28,11 @@ def test_models( model: str, dtype: str, ) -> None: - hf_model = hf_runner(model, dtype=dtype, is_embedding_model=True) - hf_outputs = hf_model.encode(example_prompts) - del hf_model + with hf_runner(model, dtype=dtype, is_embedding_model=True) as hf_model: + hf_outputs = hf_model.encode(example_prompts) - vllm_model = vllm_runner(model, dtype=dtype) - vllm_outputs = vllm_model.encode(example_prompts) - del vllm_model + with vllm_runner(model, dtype=dtype) as vllm_model: + vllm_outputs = vllm_model.encode(example_prompts) similarities = compare_embeddings(hf_outputs, vllm_outputs) all_similarities = torch.stack(similarities) diff --git a/tests/models/test_gptq_marlin.py b/tests/models/test_gptq_marlin.py index 814471b47763d..e957450cce97b 100644 --- a/tests/models/test_gptq_marlin.py +++ b/tests/models/test_gptq_marlin.py @@ -70,32 +70,29 @@ def test_models( model_name, revision = model # Run marlin. - gptq_marlin_model = vllm_runner(model_name=model_name, - revision=revision, - dtype=dtype, - quantization="marlin", - max_model_len=MAX_MODEL_LEN, - tensor_parallel_size=1) - - gptq_marlin_outputs = gptq_marlin_model.generate_greedy_logprobs( - example_prompts[:-1], max_tokens, num_logprobs) - del gptq_marlin_model + with vllm_runner(model_name=model_name, + revision=revision, + dtype=dtype, + quantization="marlin", + max_model_len=MAX_MODEL_LEN, + tensor_parallel_size=1) as gptq_marlin_model: + + gptq_marlin_outputs = gptq_marlin_model.generate_greedy_logprobs( + example_prompts[:-1], max_tokens, num_logprobs) _ROPE_DICT.clear() # clear rope cache to avoid rope dtype error # Run gptq. # The naive gptq kernel doesn't support bf16 yet. # Here we always compare fp16/bf16 gpt marlin kernel # to fp16 gptq kernel. 
- gptq_model = vllm_runner(model_name=model_name, - revision=revision, - dtype="half", - quantization="gptq", - max_model_len=MAX_MODEL_LEN, - tensor_parallel_size=1) - gptq_outputs = gptq_model.generate_greedy_logprobs(example_prompts[:-1], - max_tokens, - num_logprobs) - del gptq_model + with vllm_runner(model_name=model_name, + revision=revision, + dtype="half", + quantization="gptq", + max_model_len=MAX_MODEL_LEN, + tensor_parallel_size=1) as gptq_model: + gptq_outputs = gptq_model.generate_greedy_logprobs( + example_prompts[:-1], max_tokens, num_logprobs) check_logprobs_close( outputs_0_lst=gptq_outputs, diff --git a/tests/models/test_gptq_marlin_24.py b/tests/models/test_gptq_marlin_24.py index cc35ee803ff01..195c3e5b5863e 100644 --- a/tests/models/test_gptq_marlin_24.py +++ b/tests/models/test_gptq_marlin_24.py @@ -61,20 +61,16 @@ def test_models( max_tokens: int, num_logprobs: int, ) -> None: - marlin_24_model = vllm_runner(model_pair.model_marlin, - dtype=dtype, - quantization="gptq_marlin_24") - marlin_24_outputs = marlin_24_model.generate_greedy_logprobs( - example_prompts, max_tokens, num_logprobs) - del marlin_24_model + with vllm_runner(model_pair.model_marlin, + dtype=dtype, + quantization="gptq_marlin_24") as marlin_24_model: + marlin_24_outputs = marlin_24_model.generate_greedy_logprobs( + example_prompts, max_tokens, num_logprobs) - gptq_model = vllm_runner(model_pair.model_gptq, - dtype=dtype, - quantization="gptq") - gptq_outputs = gptq_model.generate_greedy_logprobs(example_prompts, - max_tokens, - num_logprobs) - del gptq_model + with vllm_runner(model_pair.model_gptq, dtype=dtype, + quantization="gptq") as gptq_model: + gptq_outputs = gptq_model.generate_greedy_logprobs( + example_prompts, max_tokens, num_logprobs) check_logprobs_close( outputs_0_lst=gptq_outputs, diff --git a/tests/models/test_llava.py b/tests/models/test_llava.py index f03dbdbb770e5..a1f0cff1cc0e5 100644 --- a/tests/models/test_llava.py +++ b/tests/models/test_llava.py @@ -39,8 +39,6 @@ def iter_llava_configs(model_name: str): model_and_vl_config = [ *iter_llava_configs("llava-hf/llava-1.5-7b-hf"), - # Not enough memory - # *iter_llava_configs("llava-hf/llava-1.5-13b-hf"), ] @@ -84,25 +82,23 @@ def test_models(hf_runner, vllm_runner, hf_images, vllm_images, """ model_id, vlm_config = model_and_config - hf_model = hf_runner(model_id, dtype=dtype, is_vision_model=True) - hf_outputs = hf_model.generate_greedy(HF_IMAGE_PROMPTS, - max_tokens, - images=hf_images) - del hf_model + with hf_runner(model_id, dtype=dtype, is_vision_model=True) as hf_model: + hf_outputs = hf_model.generate_greedy(HF_IMAGE_PROMPTS, + max_tokens, + images=hf_images) vllm_image_prompts = [ p.replace("<image>", "<image>" * vlm_config.image_feature_size) for p in HF_IMAGE_PROMPTS ] - vllm_model = vllm_runner(model_id, - dtype=dtype, - enforce_eager=True, - **vlm_config.as_cli_args_dict()) - vllm_outputs = vllm_model.generate_greedy(vllm_image_prompts, - max_tokens, - images=vllm_images) - del vllm_model + with vllm_runner(model_id, + dtype=dtype, + enforce_eager=True, + **vlm_config.as_cli_args_dict()) as vllm_model: + vllm_outputs = vllm_model.generate_greedy(vllm_image_prompts, + max_tokens, + images=vllm_images) for i in range(len(HF_IMAGE_PROMPTS)): hf_output_ids, hf_output_str = hf_outputs[i]
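A recurring change in this patch is replacing manually constructed runners plus `del` cleanup with `with` blocks, which relies on the runner fixtures implementing the context-manager protocol. The following is a minimal sketch of that protocol, not vLLM's actual fixture code: the class name `RunnerSketch` and its attributes are assumptions, but the cleanup body mirrors the explicit `gc.collect()`/`torch.cuda.empty_cache()` calls this patch removes from individual tests.

import gc

import torch


class RunnerSketch:
    """Hypothetical stand-in for the hf_runner/vllm_runner wrappers."""

    def __init__(self, model):
        self.model = model

    def __enter__(self):
        # The wrapper itself is bound to the `as` target.
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Drop the model reference, then reclaim GPU memory: the same
        # cleanup the old tests performed by hand after each `del`.
        self.model = None
        gc.collect()
        torch.cuda.empty_cache()
        return False  # do not swallow exceptions raised in the test body

Because `__exit__` runs even when an assertion fails, a failing test can no longer leak GPU memory into the tests that follow it.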
diff --git a/tests/models/test_llava_next.py b/tests/models/test_llava_next.py new file mode 100644 index 0000000000000..aa6ee268ae588 --- /dev/null +++ b/tests/models/test_llava_next.py @@ -0,0 +1,123 @@ +from typing import List, Tuple + +import pytest +from transformers import AutoTokenizer + +from vllm.config import VisionLanguageConfig + +from ..conftest import IMAGE_FILES + +pytestmark = pytest.mark.llava + +_PREFACE = ( + "A chat between a curious human and an artificial intelligence assistant. " + "The assistant gives helpful, detailed, and polite answers to the human's " + "questions.") + +# The image token is placed before "user" on purpose so that the test can pass +HF_IMAGE_PROMPTS = [ + f"{_PREFACE} <image>\nUSER: What's the content of the image? ASSISTANT:", + f"{_PREFACE} <image>\nUSER: What is the season? ASSISTANT:", +] + +assert len(HF_IMAGE_PROMPTS) == len(IMAGE_FILES) + + +def iter_llava_next_configs(model_name: str): + image_hw_to_feature_size = { + (336, 336): 1176, + (672, 672): 2928, + (1344, 336): 1944, + (336, 1344): 1890, + } + + for (h, w), f in image_hw_to_feature_size.items(): + for input_type, input_shape in [ + (VisionLanguageConfig.ImageInputType.PIXEL_VALUES, (1, 3, h, w)), + ]: + yield (model_name, + VisionLanguageConfig(image_input_type=input_type, + image_feature_size=f, + image_token_id=32000, + image_input_shape=input_shape, + image_processor=model_name, + image_processor_revision=None)) + + +model_and_vl_config = [ + *iter_llava_next_configs("llava-hf/llava-v1.6-vicuna-7b-hf"), +] + + +def vllm_to_hf_output(vllm_output: Tuple[List[int], str], + vlm_config: VisionLanguageConfig, model_id: str): + """Sanitize vllm output to be comparable with hf output. + The function reduces `input_ids` from 1, 32000, 32000, ..., 32000, + x1, x2, x3 ... to 1, 32000, x1, x2, x3 ... + It also reduces `output_str` from "<image><image>bla" to "bla". + """ + input_ids, output_str = vllm_output + image_token_id = vlm_config.image_token_id + + tokenizer = AutoTokenizer.from_pretrained(model_id) + image_token_str = tokenizer.decode(image_token_id) + + hf_input_ids = [ + input_id for idx, input_id in enumerate(input_ids) + if input_id != image_token_id or input_ids[idx - 1] != image_token_id + ] + hf_output_str = output_str \ + .replace(image_token_str * vlm_config.image_feature_size, " ") + + return hf_input_ids, hf_output_str
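To make the sanitization above concrete, here is a small worked example with made-up values (token id 32000 standing in for the image token, a feature size of 3); the test itself never hardcodes these numbers:

# Hypothetical ids: BOS, the image token repeated image_feature_size
# times by vLLM's prompt expansion, then ordinary text tokens.
image_token_id = 32000
input_ids = [1, 32000, 32000, 32000, 450, 3814]

# Keep only the first image token of each run, using the same
# comprehension as vllm_to_hf_output above.
hf_input_ids = [
    input_id for idx, input_id in enumerate(input_ids)
    if input_id != image_token_id or input_ids[idx - 1] != image_token_id
]
assert hf_input_ids == [1, 32000, 450, 3814]

The collapse is needed because vLLM sees the image token repeated `image_feature_size` times, while the HuggingFace reference run sees it exactly once.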
+ """ + model_id, vlm_config = model_and_config + + with hf_runner(model_id, dtype=dtype, is_vision_model=True) as hf_model: + hf_outputs = hf_model.generate_greedy(HF_IMAGE_PROMPTS, + max_tokens, + images=hf_images) + + vllm_image_prompts = [ + p.replace("", "" * vlm_config.image_feature_size) + for p in HF_IMAGE_PROMPTS + ] + + with vllm_runner( + model_id, + dtype=dtype, + # should be greater than image_feature_size + max_model_len=4096, + enforce_eager=True, + **vlm_config.as_cli_args_dict(), + ) as vllm_model: + vllm_outputs = vllm_model.generate_greedy(vllm_image_prompts, + max_tokens, + images=vllm_images) + + for i in range(len(HF_IMAGE_PROMPTS)): + hf_output_ids, hf_output_str = hf_outputs[i] + vllm_output_ids, vllm_output_str = vllm_to_hf_output( + vllm_outputs[i], vlm_config, model_id) + assert hf_output_str == vllm_output_str, ( + f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}") + assert hf_output_ids == vllm_output_ids, ( + f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}") diff --git a/tests/models/test_marlin.py b/tests/models/test_marlin.py index 8520b26718bf5..761ba6aa4d592 100644 --- a/tests/models/test_marlin.py +++ b/tests/models/test_marlin.py @@ -59,20 +59,16 @@ def test_models( max_tokens: int, num_logprobs: int, ) -> None: - marlin_model = vllm_runner(model_pair.model_marlin, - dtype=dtype, - quantization="marlin") - marlin_outputs = marlin_model.generate_greedy_logprobs( - example_prompts, max_tokens, num_logprobs) - del marlin_model - - gptq_model = vllm_runner(model_pair.model_gptq, - dtype=dtype, - quantization="gptq") - gptq_outputs = gptq_model.generate_greedy_logprobs(example_prompts, - max_tokens, - num_logprobs) - del gptq_model + with vllm_runner(model_pair.model_marlin, + dtype=dtype, + quantization="marlin") as marlin_model: + marlin_outputs = marlin_model.generate_greedy_logprobs( + example_prompts, max_tokens, num_logprobs) + + with vllm_runner(model_pair.model_gptq, dtype=dtype, + quantization="gptq") as gptq_model: + gptq_outputs = gptq_model.generate_greedy_logprobs( + example_prompts, max_tokens, num_logprobs) check_logprobs_close( outputs_0_lst=gptq_outputs, diff --git a/tests/models/test_mistral.py b/tests/models/test_mistral.py index 76b248cf14e98..6acc057fe588c 100644 --- a/tests/models/test_mistral.py +++ b/tests/models/test_mistral.py @@ -26,16 +26,13 @@ def test_models( num_logprobs: int, ) -> None: # TODO(sang): Sliding window should be tested separately. - hf_model = hf_runner(model, dtype=dtype) - hf_outputs = hf_model.generate_greedy_logprobs_limit( - example_prompts, max_tokens, num_logprobs) - del hf_model + with hf_runner(model, dtype=dtype) as hf_model: + hf_outputs = hf_model.generate_greedy_logprobs_limit( + example_prompts, max_tokens, num_logprobs) - vllm_model = vllm_runner(model, dtype=dtype) - vllm_outputs = vllm_model.generate_greedy_logprobs(example_prompts, - max_tokens, - num_logprobs) - del vllm_model + with vllm_runner(model, dtype=dtype) as vllm_model: + vllm_outputs = vllm_model.generate_greedy_logprobs( + example_prompts, max_tokens, num_logprobs) check_logprobs_close( outputs_0_lst=hf_outputs, outputs_1_lst=vllm_outputs, diff --git a/tests/models/test_models.py b/tests/models/test_models.py index e4609620387fa..71238d6909a69 100644 --- a/tests/models/test_models.py +++ b/tests/models/test_models.py @@ -34,13 +34,11 @@ def test_models( # To pass the small model tests, we need full precision. 
assert dtype == "float" - hf_model = hf_runner(model, dtype=dtype) - hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) - del hf_model + with hf_runner(model, dtype=dtype) as hf_model: + hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) - vllm_model = vllm_runner(model, dtype=dtype) - vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) - del vllm_model + with vllm_runner(model, dtype=dtype) as vllm_model: + vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) for i in range(len(example_prompts)): hf_output_ids, hf_output_str = hf_outputs[i] @@ -58,9 +56,8 @@ def test_model_print( model: str, dtype: str, ) -> None: - vllm_model = vllm_runner(model, dtype=dtype) - # This test is for verifying whether the model's extra_repr - # can be printed correctly. - print(vllm_model.model.llm_engine.model_executor.driver_worker. - model_runner.model) - del vllm_model + with vllm_runner(model, dtype=dtype) as vllm_model: + # This test is for verifying whether the model's extra_repr + # can be printed correctly. + print(vllm_model.model.llm_engine.model_executor.driver_worker. + model_runner.model) diff --git a/tests/multimodal/test_processor.py b/tests/multimodal/test_processor.py index 3df28e782dd89..51c352361702a 100644 --- a/tests/multimodal/test_processor.py +++ b/tests/multimodal/test_processor.py @@ -1,6 +1,6 @@ import numpy as np import pytest -from transformers import CLIPImageProcessor +from transformers import CLIPImageProcessor, LlavaNextImageProcessor from vllm.config import ModelConfig, VisionLanguageConfig from vllm.multimodal import MULTIMODAL_REGISTRY @@ -12,7 +12,7 @@ @pytest.mark.parametrize("dtype", ["half", "float"]) def test_clip_image_processor(hf_images, dtype): MODEL_NAME = "llava-hf/llava-1.5-7b-hf" - IMAGE_HEIGHT = IMAGE_WIDTH = 33 + IMAGE_HEIGHT = IMAGE_WIDTH = 560 hf_processor = CLIPImageProcessor.from_pretrained(MODEL_NAME) assert isinstance(hf_processor, CLIPImageProcessor) @@ -55,10 +55,61 @@ def test_clip_image_processor(hf_images, dtype): assert np.allclose(hf_arr, vllm_arr), f"Failed for key={key}" +@pytest.mark.xfail( + reason="Inconsistent image processor being used due to lack " + "of support for dynamic image token replacement") +@pytest.mark.parametrize("dtype", ["half", "float"]) +def test_llava_next_image_processor(hf_images, dtype): + MODEL_NAME = "llava-hf/llava-v1.6-34b-hf" + IMAGE_HEIGHT = IMAGE_WIDTH = 560 + + hf_processor = LlavaNextImageProcessor.from_pretrained(MODEL_NAME) + assert isinstance(hf_processor, LlavaNextImageProcessor) + + model_config = ModelConfig( + model=MODEL_NAME, + tokenizer=MODEL_NAME, + tokenizer_mode="auto", + trust_remote_code=False, + seed=0, + dtype=dtype, + revision=None, + ) + vlm_config = VisionLanguageConfig( + image_input_type=VisionLanguageConfig.ImageInputType.PIXEL_VALUES, + image_token_id=64000, + image_input_shape=(1, 3, IMAGE_HEIGHT, IMAGE_WIDTH), + image_feature_size=2928, + image_processor=MODEL_NAME, + image_processor_revision=None, + ) + + for image in hf_images: + hf_result = hf_processor.preprocess( + image, + return_tensors="pt", + ).to(dtype=_STR_DTYPE_TO_TORCH_DTYPE[dtype]) + vllm_result = MULTIMODAL_REGISTRY.process_input( + ImagePixelData(image), + model_config=model_config, + vlm_config=vlm_config, + ) + + assert hf_result.keys() == vllm_result.keys() + for key, hf_tensor in hf_result.items(): + hf_arr: np.ndarray = hf_tensor.numpy() + vllm_arr: np.ndarray = vllm_result[key].numpy() + + assert hf_arr.shape == vllm_arr.shape, f"Failed for 
key={key}" + assert np.allclose(hf_arr, vllm_arr), f"Failed for key={key}" + + +@pytest.mark.xfail( + reason="Example image pixels were not processed using HuggingFace") @pytest.mark.parametrize("dtype", ["float"]) def test_image_pixel_types(hf_images, vllm_image_tensors, dtype): MODEL_NAME = "llava-hf/llava-1.5-7b-hf" - IMAGE_HEIGHT = IMAGE_WIDTH = 33 + IMAGE_HEIGHT = IMAGE_WIDTH = 560 model_config = ModelConfig( model=MODEL_NAME, @@ -95,7 +146,4 @@ def test_image_pixel_types(hf_images, vllm_image_tensors, dtype): tensor_arr: np.ndarray = tensor_result[key].numpy() assert image_arr.shape == tensor_arr.shape, f"Failed for key={key}" - - # The examples in PR#3042 have slightly different preprocessing from - # HuggingFace's LlavaProcessor, causing the test to fail. - # assert np.allclose(image_arr, tensor_arr), f"Failed for key={key}" + assert np.allclose(image_arr, tensor_arr), f"Failed for key={key}" diff --git a/tests/multimodal/test_utils.py b/tests/multimodal/test_utils.py new file mode 100644 index 0000000000000..5a6395ac9e42a --- /dev/null +++ b/tests/multimodal/test_utils.py @@ -0,0 +1,75 @@ +import base64 +import mimetypes +from tempfile import NamedTemporaryFile +from typing import Dict, Tuple + +import numpy as np +import pytest +import pytest_asyncio +from PIL import Image + +from vllm.multimodal.utils import ImageFetchAiohttp + +# Test different image extensions (JPG/PNG) and formats (gray/RGB/RGBA) +TEST_IMAGE_URLS = [ + "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg", + "https://upload.wikimedia.org/wikipedia/commons/f/fa/Grayscale_8bits_palette_sample_image.png", + "https://upload.wikimedia.org/wikipedia/commons/thumb/9/91/Venn_diagram_rgb.svg/1280px-Venn_diagram_rgb.svg.png", + "https://upload.wikimedia.org/wikipedia/commons/0/0b/RGBA_comp.png", +] + + +@pytest_asyncio.fixture(scope="session") +async def url_images() -> Dict[str, Image.Image]: + return { + image_url: await ImageFetchAiohttp.fetch_image(image_url) + for image_url in TEST_IMAGE_URLS + } + + +def get_supported_suffixes() -> Tuple[str, ...]: + # We should at least test the file types mentioned in GPT-4 with Vision + OPENAI_SUPPORTED_SUFFIXES = ('.png', '.jpeg', '.jpg', '.webp', '.gif') + + # Additional file types that are supported by us + EXTRA_SUPPORTED_SUFFIXES = ('.bmp', '.tiff') + + return OPENAI_SUPPORTED_SUFFIXES + EXTRA_SUPPORTED_SUFFIXES + + +def _image_equals(a: Image.Image, b: Image.Image) -> bool: + return (np.asarray(a) == np.asarray(b.convert(a.mode))).all() + + +@pytest.mark.asyncio +@pytest.mark.parametrize("image_url", TEST_IMAGE_URLS) +@pytest.mark.parametrize("suffix", get_supported_suffixes()) +async def test_fetch_image_base64(url_images: Dict[str, Image.Image], + image_url: str, suffix: str): + url_image = url_images[image_url] + + try: + mime_type = Image.MIME[Image.registered_extensions()[suffix]] + except KeyError: + try: + mime_type = mimetypes.types_map[suffix] + except KeyError: + pytest.skip('No MIME type') + + with NamedTemporaryFile(suffix=suffix) as f: + try: + url_image.save(f.name) + except Exception as e: + if e.args[0] == 'cannot write mode RGBA as JPEG': + pytest.skip('Conversion not supported') + + raise + + base64_image = base64.b64encode(f.read()).decode("utf-8") + data_url = f"data:{mime_type};base64,{base64_image}" + + data_image = await ImageFetchAiohttp.fetch_image(data_url) + if _image_equals(url_image, Image.open(f)): + assert _image_equals(url_image, 
data_image) + else: + pass # Lossy format; only check that image can be opened diff --git a/tests/quantization/test_bitsandbytes.py b/tests/quantization/test_bitsandbytes.py index 4e9feb3c48148..31e938d15a1f6 100644 --- a/tests/quantization/test_bitsandbytes.py +++ b/tests/quantization/test_bitsandbytes.py @@ -16,65 +16,65 @@ capability < QUANTIZATION_METHODS['bitsandbytes'].get_min_capability(), reason='bitsandbytes is not supported on this GPU type.') def test_load_bnb_model(vllm_runner) -> None: - llm = vllm_runner('huggyllama/llama-7b', - quantization='bitsandbytes', - load_format='bitsandbytes', - enforce_eager=True) - - model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model - - # check the weights in MLP & SelfAttention are quantized to torch.uint8 - qweight = model.model.layers[0].mlp.gate_up_proj.qweight - assert qweight.dtype == torch.uint8, ( - f'Expected gate_up_proj dtype torch.uint8 but got {qweight.dtype}') - - qweight = model.model.layers[0].mlp.down_proj.qweight - assert qweight.dtype == torch.uint8, ( - f'Expected down_proj dtype torch.uint8 but got {qweight.dtype}') - - qweight = model.model.layers[0].self_attn.o_proj.qweight - assert qweight.dtype == torch.uint8, ( - f'Expected o_proj dtype torch.uint8 but got {qweight.dtype}') - - qweight = model.model.layers[0].self_attn.qkv_proj.qweight - assert qweight.dtype == torch.uint8, ( - f'Expected qkv_proj dtype torch.uint8 but got {qweight.dtype}') - - # some weights should not be quantized - weight = model.lm_head.weight - assert weight.dtype != torch.uint8, ( - 'lm_head weight dtype should not be torch.uint8') - - weight = model.model.embed_tokens.weight - assert weight.dtype != torch.uint8, ( - 'embed_tokens weight dtype should not be torch.uint8') - - weight = model.model.layers[0].input_layernorm.weight - assert weight.dtype != torch.uint8, ( - 'input_layernorm weight dtype should not be torch.uint8') - - weight = model.model.layers[0].post_attention_layernorm.weight - assert weight.dtype != torch.uint8, ( - 'input_layernorm weight dtype should not be torch.uint8') - - # check the output of the model is expected - sampling_params = SamplingParams(temperature=0.0, - logprobs=1, - prompt_logprobs=1, - max_tokens=8) - - prompts = ['That which does not kill us', 'To be or not to be,'] - expected_outputs = [ - 'That which does not kill us makes us stronger.', - 'To be or not to be, that is the question.' 
- ] - outputs = llm.generate(prompts, sampling_params=sampling_params) - - assert len(outputs) == len(prompts) - - for index in range(len(outputs)): - # compare the first line of the output - actual_output = outputs[index][1][0].split('\n', 1)[0] - expected_output = expected_outputs[index].split('\n', 1)[0] - assert actual_output == expected_output, ( - f'Expected: {expected_output}, but got: {actual_output}') + with vllm_runner('huggyllama/llama-7b', + quantization='bitsandbytes', + load_format='bitsandbytes', + enforce_eager=True) as llm: + + model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model # noqa: E501 + + # check the weights in MLP & SelfAttention are quantized to torch.uint8 + qweight = model.model.layers[0].mlp.gate_up_proj.qweight + assert qweight.dtype == torch.uint8, ( + f'Expected gate_up_proj dtype torch.uint8 but got {qweight.dtype}') + + qweight = model.model.layers[0].mlp.down_proj.qweight + assert qweight.dtype == torch.uint8, ( + f'Expected down_proj dtype torch.uint8 but got {qweight.dtype}') + + qweight = model.model.layers[0].self_attn.o_proj.qweight + assert qweight.dtype == torch.uint8, ( + f'Expected o_proj dtype torch.uint8 but got {qweight.dtype}') + + qweight = model.model.layers[0].self_attn.qkv_proj.qweight + assert qweight.dtype == torch.uint8, ( + f'Expected qkv_proj dtype torch.uint8 but got {qweight.dtype}') + + # some weights should not be quantized + weight = model.lm_head.weight + assert weight.dtype != torch.uint8, ( + 'lm_head weight dtype should not be torch.uint8') + + weight = model.model.embed_tokens.weight + assert weight.dtype != torch.uint8, ( + 'embed_tokens weight dtype should not be torch.uint8') + + weight = model.model.layers[0].input_layernorm.weight + assert weight.dtype != torch.uint8, ( + 'input_layernorm weight dtype should not be torch.uint8') + + weight = model.model.layers[0].post_attention_layernorm.weight + assert weight.dtype != torch.uint8, ( + 'post_attention_layernorm weight dtype should not be torch.uint8') + + # check the output of the model is expected + sampling_params = SamplingParams(temperature=0.0, + logprobs=1, + prompt_logprobs=1, + max_tokens=8) + + prompts = ['That which does not kill us', 'To be or not to be,'] + expected_outputs = [ + 'That which does not kill us makes us stronger.', + 'To be or not to be, that is the question.' + ] + outputs = llm.generate(prompts, sampling_params=sampling_params) + + assert len(outputs) == len(prompts) + + for index in range(len(outputs)): + # compare the first line of the output + actual_output = outputs[index][1][0].split('\n', 1)[0] + expected_output = expected_outputs[index].split('\n', 1)[0] + assert actual_output == expected_output, ( + f'Expected: {expected_output}, but got: {actual_output}')
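The dtype asserts above all encode one rule: bitsandbytes packs the linear projection weights into `torch.uint8`, while embeddings, layer norms, and the LM head keep their floating-point dtype. A single sweep expressing that rule could look like the sketch below; the suffix and substring matching are assumptions made for illustration and are not part of the test suite:

import torch


def check_bnb_dtypes(model: torch.nn.Module) -> None:
    """Hypothetical one-pass version of the per-layer asserts above."""
    quantized = ("qkv_proj.qweight", "o_proj.qweight",
                 "gate_up_proj.qweight", "down_proj.qweight")
    for name, param in model.named_parameters():
        if name.endswith(quantized):
            assert param.dtype == torch.uint8, (
                f"{name} should be quantized to torch.uint8")
        elif ("layernorm" in name or "embed_tokens" in name
              or "lm_head" in name):
            assert param.dtype != torch.uint8, (
                f"{name} should not be quantized")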
diff --git a/tests/quantization/test_compressed_tensors.py b/tests/quantization/test_compressed_tensors.py index 8b48f418fe49f..e6d8218b41372 100644 --- a/tests/quantization/test_compressed_tensors.py +++ b/tests/quantization/test_compressed_tensors.py @@ -5,49 +5,58 @@ import torch +from vllm import SamplingParams from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( # noqa: E501 CompressedTensorsLinearMethod, CompressedTensorsW8A8DynamicToken, CompressedTensorsW8A8StaticTensor) def test_compressed_tensors_w8a8_static_setup(vllm_runner): - model_path = "nm-testing/tinyllama-one-shot-static-quant-test-compressed" - llm = vllm_runner(model_path, quantization="sparseml", enforce_eager=True) - model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model - layer = model.model.layers[0] + model_path = "nm-testing/tinyllama-oneshot-w8a8-static-v2" + with vllm_runner(model_path, enforce_eager=True) as llm: + model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model # noqa: E501 + layer = model.model.layers[0] - qkv_proj = layer.self_attn.qkv_proj - o_proj = layer.self_attn.o_proj - gate_up_proj = layer.mlp.gate_up_proj - down_proj = layer.mlp.down_proj + qkv_proj = layer.self_attn.qkv_proj + o_proj = layer.self_attn.o_proj + gate_up_proj = layer.mlp.gate_up_proj + down_proj = layer.mlp.down_proj - assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod) - assert isinstance(o_proj.quant_method, CompressedTensorsLinearMethod) - assert isinstance(gate_up_proj.quant_method, CompressedTensorsLinearMethod) - assert isinstance(down_proj.quant_method, CompressedTensorsLinearMethod) + assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod) + assert isinstance(o_proj.quant_method, CompressedTensorsLinearMethod) + assert isinstance(gate_up_proj.quant_method, + CompressedTensorsLinearMethod) + assert isinstance(down_proj.quant_method, + CompressedTensorsLinearMethod) - assert isinstance(qkv_proj.scheme, CompressedTensorsW8A8StaticTensor) + assert isinstance(qkv_proj.scheme, CompressedTensorsW8A8StaticTensor) - assert qkv_proj.weight.dtype is torch.int8 - assert o_proj.weight.dtype is torch.int8 - assert gate_up_proj.weight.dtype is torch.int8 + assert qkv_proj.weight.dtype is torch.int8 + assert o_proj.weight.dtype is torch.int8 + assert gate_up_proj.weight.dtype is torch.int8 - assert qkv_proj.weight_scale.shard_splitter is not None - assert qkv_proj.weight_scale.logical_widths is not None - assert qkv_proj.input_scale.dtype is torch.float32 + assert qkv_proj.weight_scale.shard_splitter is not None + assert qkv_proj.weight_scale.logical_widths is not None + assert qkv_proj.input_scale.dtype is torch.float32 + + +def test_compressed_tensors_no_enforce_eager(vllm_runner): + model_path = "nm-testing/tinyllama-oneshot-w8a8-static-v2" + with vllm_runner(model_path) as llm: + sampling_params = SamplingParams() + output = llm.generate("Hello world!", sampling_params=sampling_params) + assert output def test_compressed_tensors_w8a8_dynanmic_per_token(vllm_runner): - model_path =
"nm-testing/tinyllama-one-shot-dynamic-test" - llm = vllm_runner(model_path, - quantization="sparseml", - enforce_eager=True, - dtype=torch.float16) - model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model - layer = model.model.layers[0] - - qkv_proj = layer.self_attn.qkv_proj - - assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod) - assert isinstance(qkv_proj.scheme, CompressedTensorsW8A8DynamicToken) - assert qkv_proj.weight.dtype is torch.int8 + model_path = "nm-testing/tinyllama-oneshot-w8a8-dynamic-token-v2" + with vllm_runner(model_path, enforce_eager=True, + dtype=torch.float16) as llm: + model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model # noqa: E501 + layer = model.model.layers[0] + + qkv_proj = layer.self_attn.qkv_proj + + assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod) + assert isinstance(qkv_proj.scheme, CompressedTensorsW8A8DynamicToken) + assert qkv_proj.weight.dtype is torch.int8 diff --git a/tests/quantization/test_fp8.py b/tests/quantization/test_fp8.py index 607544a1c8394..fccce7f7b59a7 100644 --- a/tests/quantization/test_fp8.py +++ b/tests/quantization/test_fp8.py @@ -16,9 +16,9 @@ capability < QUANTIZATION_METHODS["fp8"].get_min_capability(), reason="FP8 is not supported on this GPU type.") def test_load_fp16_model(vllm_runner) -> None: - llm = vllm_runner("facebook/opt-125m", quantization="fp8") + with vllm_runner("facebook/opt-125m", quantization="fp8") as llm: - model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model - fc1 = model.model.decoder.layers[0].fc1 - assert isinstance(fc1.quant_method, Fp8LinearMethod) - assert fc1.weight.dtype == torch.float8_e4m3fn + model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model # noqa: E501 + fc1 = model.model.decoder.layers[0].fc1 + assert isinstance(fc1.quant_method, Fp8LinearMethod) + assert fc1.weight.dtype == torch.float8_e4m3fn diff --git a/tests/samplers/test_beam_search.py b/tests/samplers/test_beam_search.py index 2682f284505bd..64f3ce94b7a83 100644 --- a/tests/samplers/test_beam_search.py +++ b/tests/samplers/test_beam_search.py @@ -2,10 +2,8 @@ Run `pytest tests/samplers/test_beam_search.py`. """ -import gc import pytest -import torch # FIXME(zhuohan): The test can not pass if we: # 1. Increase max_tokens to 256. @@ -30,19 +28,13 @@ def test_beam_search_single_input( beam_width: int, ) -> None: example_prompts = example_prompts[:1] - hf_model = hf_runner(model, dtype=dtype) - hf_outputs = hf_model.generate_beam_search(example_prompts, beam_width, - max_tokens) - del hf_model - - vllm_model = vllm_runner(model, dtype=dtype) - vllm_outputs = vllm_model.generate_beam_search(example_prompts, beam_width, + with hf_runner(model, dtype=dtype) as hf_model: + hf_outputs = hf_model.generate_beam_search(example_prompts, beam_width, max_tokens) - del vllm_model - # NOTE(woosuk): For some reason, the following GC is required to avoid - # GPU OOM errors in the following tests using `vllm_runner`. 
- gc.collect() - torch.cuda.empty_cache() + + with vllm_runner(model, dtype=dtype) as vllm_model: + vllm_outputs = vllm_model.generate_beam_search(example_prompts, + beam_width, max_tokens) for i in range(len(example_prompts)): hf_output_ids, _ = hf_outputs[i] diff --git a/tests/samplers/test_ignore_eos.py b/tests/samplers/test_ignore_eos.py index 67b5168bea0e6..dc2482d85a91f 100644 --- a/tests/samplers/test_ignore_eos.py +++ b/tests/samplers/test_ignore_eos.py @@ -22,11 +22,12 @@ def test_ignore_eos( dtype: str, max_tokens: int, ) -> None: - vllm_model = vllm_runner(model, dtype=dtype) - sampling_params = SamplingParams(max_tokens=max_tokens, ignore_eos=True) + with vllm_runner(model, dtype=dtype) as vllm_model: + sampling_params = SamplingParams(max_tokens=max_tokens, + ignore_eos=True) - for prompt in example_prompts: - ignore_eos_output = vllm_model.model.generate( - prompt, sampling_params=sampling_params) - output_length = len(ignore_eos_output[0].outputs[0].token_ids) - assert output_length == max_tokens + for prompt in example_prompts: + ignore_eos_output = vllm_model.model.generate( + prompt, sampling_params=sampling_params) + output_length = len(ignore_eos_output[0].outputs[0].token_ids) + assert output_length == max_tokens diff --git a/tests/samplers/test_logits_processor.py b/tests/samplers/test_logits_processor.py index 0ccbabfff6403..2979470120710 100644 --- a/tests/samplers/test_logits_processor.py +++ b/tests/samplers/test_logits_processor.py @@ -14,46 +14,46 @@ def test_logits_processor_force_generate( model: str, dtype: str, ) -> None: - vllm_model = vllm_runner(model, dtype=dtype) - tokenizer = vllm_model.model.get_tokenizer() - repeat_times = 2 - enforced_answers = " vLLM" - vllm_token_ids = tokenizer.encode(enforced_answers, - add_special_tokens=False) - max_tokens = len(vllm_token_ids) * repeat_times - - def pick_vllm(token_ids, logits): - token_id = vllm_token_ids[len(token_ids) % len(vllm_token_ids)] - logits[token_id] = torch.finfo(logits.dtype).max - return logits - - params_with_logprobs = SamplingParams( - logits_processors=[pick_vllm], - prompt_logprobs=3, - max_tokens=max_tokens, - ) - - # test logits_processors when prompt_logprobs is not None - vllm_model.model._add_request( - example_prompts[0], - params=params_with_logprobs, - ) - - # test prompt_logprobs is not None - vllm_model.model._add_request( - example_prompts[1], - params=SamplingParams( + with vllm_runner(model, dtype=dtype) as vllm_model: + tokenizer = vllm_model.model.get_tokenizer() + repeat_times = 2 + enforced_answers = " vLLM" + vllm_token_ids = tokenizer.encode(enforced_answers, + add_special_tokens=False) + max_tokens = len(vllm_token_ids) * repeat_times + + def pick_vllm(token_ids, logits): + token_id = vllm_token_ids[len(token_ids) % len(vllm_token_ids)] + logits[token_id] = torch.finfo(logits.dtype).max + return logits + + params_with_logprobs = SamplingParams( + logits_processors=[pick_vllm], prompt_logprobs=3, max_tokens=max_tokens, - ), - ) - - # test grouped requests - vllm_model.model._add_request( - example_prompts[2], - params=SamplingParams(max_tokens=max_tokens), - ) - - outputs = vllm_model.model._run_engine(use_tqdm=False) - - assert outputs[0].outputs[0].text == enforced_answers * repeat_times + ) + + # test logits_processors when prompt_logprobs is not None + vllm_model.model._add_request( + example_prompts[0], + params=params_with_logprobs, + ) + + # test prompt_logprobs is not None + vllm_model.model._add_request( + example_prompts[1], + params=SamplingParams( + 
prompt_logprobs=3, + max_tokens=max_tokens, + ), + ) + + # test grouped requests + vllm_model.model._add_request( + example_prompts[2], + params=SamplingParams(max_tokens=max_tokens), + ) + + outputs = vllm_model.model._run_engine(use_tqdm=False) + + assert outputs[0].outputs[0].text == enforced_answers * repeat_times diff --git a/tests/samplers/test_logprobs.py b/tests/samplers/test_logprobs.py index 61720cccf50b4..233540cdc391f 100644 --- a/tests/samplers/test_logprobs.py +++ b/tests/samplers/test_logprobs.py @@ -32,28 +32,27 @@ def test_get_prompt_logprobs( max_num_batched_tokens = chunked_prefill_token_size max_tokens = 5 - hf_model = hf_runner(model, dtype=dtype) - hf_logprobs = hf_model.generate_greedy_logprobs( - example_prompts, - max_tokens=max_tokens, - ) - del hf_model - - vllm_model = vllm_runner( - model, - dtype=dtype, - max_logprobs=num_top_logprobs, - enable_chunked_prefill=enable_chunked_prefill, - max_num_batched_tokens=max_num_batched_tokens, - max_num_seqs=max_num_seqs, - ) - vllm_sampling_params = SamplingParams(max_tokens=max_tokens, - logprobs=num_top_logprobs, - prompt_logprobs=num_top_logprobs, - temperature=0.0, - detokenize=detokenize) - vllm_results = vllm_model.model.generate( - example_prompts, sampling_params=vllm_sampling_params) + with hf_runner(model, dtype=dtype) as hf_model: + hf_logprobs = hf_model.generate_greedy_logprobs( + example_prompts, + max_tokens=max_tokens, + ) + + with vllm_runner( + model, + dtype=dtype, + max_logprobs=num_top_logprobs, + enable_chunked_prefill=enable_chunked_prefill, + max_num_batched_tokens=max_num_batched_tokens, + max_num_seqs=max_num_seqs, + ) as vllm_model: + vllm_sampling_params = SamplingParams(max_tokens=max_tokens, + logprobs=num_top_logprobs, + prompt_logprobs=num_top_logprobs, + temperature=0.0, + detokenize=detokenize) + vllm_results = vllm_model.model.generate( + example_prompts, sampling_params=vllm_sampling_params) # Test whether logprobs are included in the results. 
for result in vllm_results: diff --git a/tests/samplers/test_ranks.py b/tests/samplers/test_ranks.py index 5e93238d709ec..ed2fee1ae252e 100644 --- a/tests/samplers/test_ranks.py +++ b/tests/samplers/test_ranks.py @@ -17,16 +17,27 @@ def test_ranks( num_top_logprobs = 5 num_prompt_logprobs = 5 - vllm_model = vllm_runner(model, dtype=dtype, max_logprobs=num_top_logprobs) - - ## Test greedy logprobs ranks - vllm_sampling_params = SamplingParams(temperature=0.0, - top_p=1.0, - max_tokens=max_tokens, - logprobs=num_top_logprobs, - prompt_logprobs=num_prompt_logprobs) - vllm_results = vllm_model.generate_w_logprobs(example_prompts, - vllm_sampling_params) + with vllm_runner(model, dtype=dtype, + max_logprobs=num_top_logprobs) as vllm_model: + + ## Test greedy logprobs ranks + vllm_sampling_params = SamplingParams( + temperature=0.0, + top_p=1.0, + max_tokens=max_tokens, + logprobs=num_top_logprobs, + prompt_logprobs=num_prompt_logprobs) + vllm_results = vllm_model.generate_w_logprobs(example_prompts, + vllm_sampling_params) + + ## Test non-greedy logprobs ranks + sampling_params = SamplingParams(temperature=1.0, + top_p=1.0, + max_tokens=max_tokens, + logprobs=num_top_logprobs, + prompt_logprobs=num_prompt_logprobs) + res = vllm_model.generate_w_logprobs(example_prompts, sampling_params) + for result in vllm_results: assert result[2] is not None assert len(result[2]) == len(result[0]) @@ -35,13 +46,6 @@ def test_ranks( assert token in logprobs assert logprobs[token].rank == 1 - ## Test non-greedy logprobs ranks - sampling_params = SamplingParams(temperature=1.0, - top_p=1.0, - max_tokens=max_tokens, - logprobs=num_top_logprobs, - prompt_logprobs=num_prompt_logprobs) - res = vllm_model.generate_w_logprobs(example_prompts, sampling_params) for result in res: assert result[2] is not None assert len(result[2]) == len(result[0]) diff --git a/tests/samplers/test_seeded_generate.py b/tests/samplers/test_seeded_generate.py index fef5ff3fb9e8e..88067f19c8f07 100644 --- a/tests/samplers/test_seeded_generate.py +++ b/tests/samplers/test_seeded_generate.py @@ -17,9 +17,8 @@ @pytest.fixture def vllm_model(vllm_runner): - vllm_model = vllm_runner(MODEL, dtype="half") - yield vllm_model - del vllm_model + with vllm_runner(MODEL, dtype="half") as vllm_model: + yield vllm_model @pytest.mark.parametrize("seed", RANDOM_SEEDS) diff --git a/tests/tensorizer_loader/test_tensorizer.py b/tests/tensorizer_loader/test_tensorizer.py index 1579d53a7fe29..b558bfc6df21b 100644 --- a/tests/tensorizer_loader/test_tensorizer.py +++ b/tests/tensorizer_loader/test_tensorizer.py @@ -1,4 +1,3 @@ -import gc import json import os import subprocess @@ -7,7 +6,6 @@ import openai import pytest import ray -import torch from vllm import SamplingParams # yapf: disable @@ -71,72 +69,66 @@ def test_can_deserialize_s3(vllm_runner): model_ref = "EleutherAI/pythia-1.4b" tensorized_path = f"s3://tensorized/{model_ref}/fp16/model.tensors" - loaded_hf_model = vllm_runner(model_ref, + with vllm_runner(model_ref, load_format="tensorizer", model_loader_extra_config=TensorizerConfig( tensorizer_uri=tensorized_path, num_readers=1, s3_endpoint="object.ord1.coreweave.com", - )) + )) as loaded_hf_model: - deserialized_outputs = loaded_hf_model.generate(prompts, sampling_params) + deserialized_outputs = loaded_hf_model.generate(prompts, sampling_params) # noqa: E501 - assert deserialized_outputs + assert deserialized_outputs @pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed") def 
test_deserialized_encrypted_vllm_model_has_same_outputs( vllm_runner, tmp_path): - vllm_model = vllm_runner(model_ref) - model_path = tmp_path / (model_ref + ".tensors") - key_path = tmp_path / (model_ref + ".key") - outputs = vllm_model.generate(prompts, sampling_params) - - config_for_serializing = TensorizerConfig(tensorizer_uri=model_path) - serialize_vllm_model(vllm_model.model.llm_engine, - config_for_serializing, - encryption_key_path=key_path) + with vllm_runner(model_ref) as vllm_model: + model_path = tmp_path / (model_ref + ".tensors") + key_path = tmp_path / (model_ref + ".key") + outputs = vllm_model.generate(prompts, sampling_params) - del vllm_model - gc.collect() - torch.cuda.empty_cache() + config_for_serializing = TensorizerConfig(tensorizer_uri=model_path) + serialize_vllm_model(vllm_model.model.llm_engine, + config_for_serializing, + encryption_key_path=key_path) config_for_deserializing = TensorizerConfig(tensorizer_uri=model_path, encryption_keyfile=key_path) - loaded_vllm_model = vllm_runner( + with vllm_runner( model_ref, load_format="tensorizer", - model_loader_extra_config=config_for_deserializing) + model_loader_extra_config=config_for_deserializing) as loaded_vllm_model: # noqa: E501 - deserialized_outputs = loaded_vllm_model.generate(prompts, sampling_params) + deserialized_outputs = loaded_vllm_model.generate(prompts, sampling_params) # noqa: E501 - assert outputs == deserialized_outputs + assert outputs == deserialized_outputs def test_deserialized_hf_model_has_same_outputs(hf_runner, vllm_runner, tmp_path): - hf_model = hf_runner(model_ref) - model_path = tmp_path / (model_ref + ".tensors") - max_tokens = 50 - outputs = hf_model.generate_greedy(prompts, max_tokens=max_tokens) - with open_stream(model_path, "wb+") as stream: - serializer = TensorSerializer(stream) - serializer.write_module(hf_model.model) - del hf_model - gc.collect() - torch.cuda.empty_cache() - loaded_hf_model = vllm_runner(model_ref, + with hf_runner(model_ref) as hf_model: + model_path = tmp_path / (model_ref + ".tensors") + max_tokens = 50 + outputs = hf_model.generate_greedy(prompts, max_tokens=max_tokens) + with open_stream(model_path, "wb+") as stream: + serializer = TensorSerializer(stream) + serializer.write_module(hf_model.model) + + with vllm_runner(model_ref, load_format="tensorizer", model_loader_extra_config=TensorizerConfig( tensorizer_uri=model_path, num_readers=1, - )) + )) as loaded_hf_model: - deserialized_outputs = loaded_hf_model.generate_greedy( - prompts, max_tokens=max_tokens) + deserialized_outputs = loaded_hf_model.generate_greedy( + prompts, max_tokens=max_tokens) - assert outputs == deserialized_outputs + assert outputs == deserialized_outputs def test_vllm_model_can_load_with_lora(vllm_runner, tmp_path): @@ -150,16 +142,13 @@ def test_vllm_model_can_load_with_lora(vllm_runner, tmp_path): test_prompts = create_test_prompts(lora_path) # Serialize model before deserializing and binding LoRA adapters - vllm_model = vllm_runner(model_ref, ) - model_path = tmp_path / (model_ref + ".tensors") + with vllm_runner(model_ref, ) as vllm_model: + model_path = tmp_path / (model_ref + ".tensors") - serialize_vllm_model(vllm_model.model.llm_engine, - TensorizerConfig(tensorizer_uri=model_path)) + serialize_vllm_model(vllm_model.model.llm_engine, + TensorizerConfig(tensorizer_uri=model_path)) - del vllm_model - gc.collect() - torch.cuda.empty_cache() - loaded_vllm_model = vllm_runner( + with vllm_runner( model_ref, load_format="tensorizer", 
model_loader_extra_config=TensorizerConfig( @@ -172,10 +161,10 @@ def test_vllm_model_can_load_with_lora(vllm_runner, tmp_path): max_cpu_loras=2, max_num_seqs=50, max_model_len=1000, - ) - process_requests(loaded_vllm_model.model.llm_engine, test_prompts) + ) as loaded_vllm_model: + process_requests(loaded_vllm_model.model.llm_engine, test_prompts) - assert loaded_vllm_model + assert loaded_vllm_model def test_load_without_tensorizer_load_format(vllm_runner): @@ -188,19 +177,15 @@ @pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed") def test_openai_apiserver_with_tensorizer(vllm_runner, tmp_path): ## Serialize model - vllm_model = vllm_runner(model_ref, ) - model_path = tmp_path / (model_ref + ".tensors") - - serialize_vllm_model(vllm_model.model.llm_engine, - TensorizerConfig(tensorizer_uri=model_path)) + with vllm_runner(model_ref, ) as vllm_model: + model_path = tmp_path / (model_ref + ".tensors") - model_loader_extra_config = { - "tensorizer_uri": str(model_path), - } + serialize_vllm_model(vllm_model.model.llm_engine, + TensorizerConfig(tensorizer_uri=model_path)) - del vllm_model - gc.collect() - torch.cuda.empty_cache() + model_loader_extra_config = { + "tensorizer_uri": str(model_path), + } ## Start OpenAI API server openai_args = [ @@ -262,18 +247,15 @@ def test_vllm_tensorized_model_has_same_outputs(vllm_runner, tmp_path): model_path = tmp_path / (model_ref + ".tensors") config = TensorizerConfig(tensorizer_uri=str(model_path)) - vllm_model = vllm_runner(model_ref) - outputs = vllm_model.generate(prompts, sampling_params) - serialize_vllm_model(vllm_model.model.llm_engine, config) + with vllm_runner(model_ref) as vllm_model: + outputs = vllm_model.generate(prompts, sampling_params) + serialize_vllm_model(vllm_model.model.llm_engine, config) - assert is_vllm_tensorized(config) - del vllm_model - gc.collect() - torch.cuda.empty_cache() + assert is_vllm_tensorized(config) - loaded_vllm_model = vllm_runner(model_ref, - load_format="tensorizer", - model_loader_extra_config=config) - deserialized_outputs = loaded_vllm_model.generate(prompts, sampling_params) + with vllm_runner(model_ref, + load_format="tensorizer", + model_loader_extra_config=config) as loaded_vllm_model: + deserialized_outputs = loaded_vllm_model.generate(prompts, sampling_params) # noqa: E501 - assert outputs == deserialized_outputs + assert outputs == deserialized_outputs diff --git a/tests/test_sharded_state_loader.py b/tests/test_sharded_state_loader.py index 022fb36b346f4..de79c3b945d4d 100644 --- a/tests/test_sharded_state_loader.py +++ b/tests/test_sharded_state_loader.py @@ -39,7 +39,8 @@ def test_filter_subtensors(): filtered_state_dict = ShardedStateLoader._filter_subtensors(state_dict) assert tuple(filtered_state_dict.keys()) == ("a", "b", "c") for key, tensor in filtered_state_dict.items(): - assert tensor.equal(state_dict[key]) + # NOTE: don't use `equal` here, as the tensor might contain NaNs + assert tensor is state_dict[key] @pytest.fixture(scope="module") diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 7e12f1ba14cde..440b0e8afa99a 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -1,35 +1,47 @@ -from typing import Optional, Tuple, Type +import contextlib +from typing import List, Optional, Tuple, Type import torch try: - from vllm._C import cache_ops as vllm_cache_ops - from vllm._C import ops as vllm_ops + import vllm._C except ImportError as e: from vllm.logger import init_logger logger =
init_logger(__name__) logger.warning("Failed to import from vllm._C with %r", e) +with contextlib.suppress(ImportError): + import vllm._moe_C + +with contextlib.suppress(ImportError): + # ruff: noqa: F401 + import vllm._punica_C + + +def is_custom_op_supported(op_name: str) -> bool: + op, overloads = torch._C._jit_get_operation(op_name) + return op is not None + # activation ops def silu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None: - vllm_ops.silu_and_mul(out, x) + torch.ops._C.silu_and_mul(out, x) def gelu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None: - vllm_ops.gelu_and_mul(out, x) + torch.ops._C.gelu_and_mul(out, x) def gelu_tanh_and_mul(out: torch.Tensor, x: torch.Tensor) -> None: - vllm_ops.gelu_tanh_and_mul(out, x) + torch.ops._C.gelu_tanh_and_mul(out, x) def gelu_fast(out: torch.Tensor, x: torch.Tensor) -> None: - vllm_ops.gelu_fast(out, x) + torch.ops._C.gelu_fast(out, x) def gelu_new(out: torch.Tensor, x: torch.Tensor) -> None: - vllm_ops.gelu_new(out, x) + torch.ops._C.gelu_new(out, x) # page attention ops @@ -53,7 +65,7 @@ def paged_attention_v1( blocksparse_block_size: int = 64, blocksparse_head_sliding_step: int = 0, ) -> None: - vllm_ops.paged_attention_v1( + torch.ops._C.paged_attention_v1( out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, seq_lens, block_size, max_seq_len, alibi_slopes, kv_cache_dtype, kv_scale, tp_rank, blocksparse_local_blocks, blocksparse_vert_stride, @@ -83,7 +95,7 @@ def paged_attention_v2( blocksparse_block_size: int = 64, blocksparse_head_sliding_step: int = 0, ) -> None: - vllm_ops.paged_attention_v2( + torch.ops._C.paged_attention_v2( out, exp_sum, max_logits, tmp_out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, seq_lens, block_size, max_seq_len, alibi_slopes, kv_cache_dtype, kv_scale, tp_rank, @@ -100,8 +112,8 @@ def rotary_embedding( cos_sin_cache: torch.Tensor, is_neox: bool, ) -> None: - vllm_ops.rotary_embedding(positions, query, key, head_size, cos_sin_cache, - is_neox) + torch.ops._C.rotary_embedding(positions, query, key, head_size, + cos_sin_cache, is_neox) def batched_rotary_embedding(positions: torch.Tensor, query: torch.Tensor, @@ -109,20 +121,20 @@ def batched_rotary_embedding(positions: torch.Tensor, query: torch.Tensor, cos_sin_cache: torch.Tensor, is_neox: bool, rot_dim: int, cos_sin_cache_offsets: torch.Tensor) -> None: - vllm_ops.batched_rotary_embedding(positions, query, key, head_size, - cos_sin_cache, is_neox, rot_dim, - cos_sin_cache_offsets) + torch.ops._C.batched_rotary_embedding(positions, query, key, head_size, + cos_sin_cache, is_neox, rot_dim, + cos_sin_cache_offsets) # layer norm ops def rms_norm(out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, epsilon: float) -> None: - vllm_ops.rms_norm(out, input, weight, epsilon) + torch.ops._C.rms_norm(out, input, weight, epsilon) def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, epsilon: float) -> None: - vllm_ops.fused_add_rms_norm(input, residual, weight, epsilon) + torch.ops._C.fused_add_rms_norm(input, residual, weight, epsilon) # quantization ops @@ -130,13 +142,13 @@ def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor, def awq_dequantize(qweight: torch.Tensor, scales: torch.Tensor, zeros: torch.Tensor, split_k_iters: int, thx: int, thy: int) -> torch.Tensor: - return vllm_ops.awq_dequantize(qweight, scales, zeros, split_k_iters, thx, - thy) + return torch.ops._C.awq_dequantize(qweight, scales, zeros, split_k_iters, + thx, thy) def awq_gemm(input: 
torch.Tensor, qweight: torch.Tensor, qzeros: torch.Tensor, scales: torch.Tensor, split_k_iters: int) -> torch.Tensor: - return vllm_ops.awq_gemm(input, qweight, qzeros, scales, split_k_iters) + return torch.ops._C.awq_gemm(input, qweight, qzeros, scales, split_k_iters) # gptq @@ -144,27 +156,27 @@ def gptq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, b_gptq_qzeros: torch.Tensor, b_gptq_scales: torch.Tensor, b_g_idx: torch.Tensor, use_exllama: bool, bit: int) -> torch.Tensor: - return vllm_ops.gptq_gemm(a, b_q_weight, b_gptq_qzeros, b_gptq_scales, - b_g_idx, use_exllama, bit) + return torch.ops._C.gptq_gemm(a, b_q_weight, b_gptq_qzeros, b_gptq_scales, + b_g_idx, use_exllama, bit) def gptq_shuffle(q_weight: torch.Tensor, q_perm: torch.Tensor, bit: int) -> None: - vllm_ops.gptq_shuffle(q_weight, q_perm, bit) + torch.ops._C.gptq_shuffle(q_weight, q_perm, bit) # squeezellm def squeezellm_gemm(vec: torch.Tensor, mat: torch.Tensor, mul: torch.Tensor, lookup_table: torch.Tensor) -> None: - vllm_ops.squeezellm_gemm(vec, mat, mul, lookup_table) + torch.ops._C.squeezellm_gemm(vec, mat, mul, lookup_table) # marlin def marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, b_scales: torch.Tensor, workspace: torch.Tensor, size_m: int, size_n: int, size_k: int) -> torch.Tensor: - return vllm_ops.marlin_gemm(a, b_q_weight, b_scales, workspace, size_m, - size_n, size_k) + return torch.ops._C.marlin_gemm(a, b_q_weight, b_scales, workspace, size_m, + size_n, size_k) # marlin_24 @@ -172,9 +184,9 @@ def gptq_marlin_24_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, b_meta: torch.Tensor, b_scales: torch.Tensor, workspace: torch.Tensor, num_bits: int, size_m: int, size_n: int, size_k: int) -> torch.Tensor: - return vllm_ops.gptq_marlin_24_gemm(a, b_q_weight, b_meta, b_scales, - workspace, num_bits, size_m, size_n, - size_k) + return torch.ops._C.gptq_marlin_24_gemm(a, b_q_weight, b_meta, b_scales, + workspace, num_bits, size_m, + size_n, size_k) # cutlass @@ -188,7 +200,7 @@ def cutlass_scaled_mm_dq(a: torch.Tensor, b: torch.Tensor, n = b.shape[1] out = torch.empty((m, n), dtype=out_dtype, device=a.device) - vllm_ops.cutlass_scaled_mm_dq(out, a, b, scale_a, scale_b) + torch.ops._C.cutlass_scaled_mm_dq(out, a, b, scale_a, scale_b) return out @@ -198,21 +210,22 @@ def aqlm_gemm(input: torch.Tensor, codes: torch.Tensor, codebooks: torch.Tensor, scales: torch.Tensor, codebook_partition_sizes: torch.Tensor, bias: Optional[torch.Tensor]) -> torch.Tensor: - return vllm_ops.aqlm_gemm(input, codes, codebooks, scales, - codebook_partition_sizes, bias) + return torch.ops._C.aqlm_gemm(input, codes, codebooks, scales, + codebook_partition_sizes, bias) def aqlm_dequant(codes: torch.Tensor, codebooks: torch.Tensor, codebook_partition_sizes: torch.Tensor) -> torch.Tensor: - return vllm_ops.aqlm_dequant(codes, codebooks, codebook_partition_sizes) + return torch.ops._C.aqlm_dequant(codes, codebooks, + codebook_partition_sizes) # gptq_marlin def gptq_marlin_repack(b_q_weight: torch.Tensor, perm: torch.Tensor, size_k: int, size_n: int, num_bits: int) -> torch.Tensor: - return vllm_ops.gptq_marlin_repack(b_q_weight, perm, size_k, size_n, - num_bits) + return torch.ops._C.gptq_marlin_repack(b_q_weight, perm, size_k, size_n, + num_bits) def gptq_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, @@ -220,9 +233,9 @@ def gptq_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, perm: torch.Tensor, workspace: torch.Tensor, num_bits: int, size_m: int, size_n: int, size_k: int, is_k_full: bool) -> torch.Tensor: - return 
vllm_ops.gptq_marlin_gemm(a, b_q_weight, b_scales, g_idx, perm, - workspace, num_bits, size_m, size_n, - size_k, is_k_full) + return torch.ops._C.gptq_marlin_gemm(a, b_q_weight, b_scales, g_idx, perm, + workspace, num_bits, size_m, size_n, + size_k, is_k_full) # fp8 @@ -259,9 +272,9 @@ def scaled_fp8_quant( output = torch.empty_like(input, dtype=torch.float8_e4m3fn) if scale is None: scale = torch.zeros(1, device=input.device, dtype=torch.float32) - vllm_ops.dynamic_scaled_fp8_quant(output, input, scale) + torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale) else: - vllm_ops.static_scaled_fp8_quant(output, input, scale) + torch.ops._C.static_scaled_fp8_quant(output, input, scale) return output, scale @@ -284,14 +297,14 @@ def scaled_int8_quant( output = torch.empty_like(input, dtype=torch.int8) if scale is not None: # static-per-tensor quantization. - vllm_ops.static_scaled_int8_quant(output, input, scale) + torch.ops._C.static_scaled_int8_quant(output, input, scale) return output, scale # dynamic-per-token quantization. input_scales = torch.empty((input.numel() // input.shape[-1], 1), device=input.device, dtype=torch.float32) - vllm_ops.dynamic_scaled_int8_quant(output, input, input_scales) + torch.ops._C.dynamic_scaled_int8_quant(output, input, input_scales) return output, input_scales @@ -300,9 +313,16 @@ def moe_align_block_size(topk_ids: torch.Tensor, num_experts: int, block_size: int, sorted_token_ids: torch.Tensor, experts_ids: torch.Tensor, num_tokens_post_pad: torch.Tensor) -> None: - vllm_ops.moe_align_block_size(topk_ids, num_experts, block_size, - sorted_token_ids, experts_ids, - num_tokens_post_pad) + torch.ops._C.moe_align_block_size(topk_ids, num_experts, block_size, + sorted_token_ids, experts_ids, + num_tokens_post_pad) + + +def topk_softmax(topk_weights: torch.Tensor, topk_ids: torch.Tensor, + token_expert_indicies: torch.Tensor, + gating_output: float) -> None: + torch.ops._moe_C.topk_softmax(topk_weights, topk_ids, + token_expert_indicies, gating_output) def reshape_and_cache( @@ -314,8 +334,9 @@ def reshape_and_cache( kv_cache_dtype: str, kv_scale: float, ) -> None: - vllm_cache_ops.reshape_and_cache(key, value, key_cache, value_cache, - slot_mapping, kv_cache_dtype, kv_scale) + torch.ops._C_cache_ops.reshape_and_cache(key, value, key_cache, + value_cache, slot_mapping, + kv_cache_dtype, kv_scale) def reshape_and_cache_flash( @@ -326,25 +347,115 @@ def reshape_and_cache_flash( slot_mapping: torch.Tensor, kv_cache_dtype: str, ) -> None: - vllm_cache_ops.reshape_and_cache_flash(key, value, key_cache, value_cache, - slot_mapping, kv_cache_dtype) + torch.ops._C_cache_ops.reshape_and_cache_flash(key, value, key_cache, + value_cache, slot_mapping, + kv_cache_dtype) def copy_blocks(key_caches: torch.Tensor, value_caches: torch.Tensor, block_mapping: torch.Tensor) -> None: - vllm_cache_ops.copy_blocks(key_caches, value_caches, block_mapping) + torch.ops._C_cache_ops.copy_blocks(key_caches, value_caches, block_mapping) def swap_blocks(src: torch.Tensor, dst: torch.Tensor, block_mapping: torch.Tensor) -> None: - vllm_cache_ops.swap_blocks(src, dst, block_mapping) + torch.ops._C_cache_ops.swap_blocks(src, dst, block_mapping) def convert_fp8(output: torch.Tensor, input: torch.Tensor, scale: float = 1.0, kv_dtype: str = "fp8") -> None: - vllm_cache_ops.convert_fp8(output, input, scale, kv_dtype) + torch.ops._C_cache_ops.convert_fp8(output, input, scale, kv_dtype) + + +def get_device_attribute(attribute: int, device: int) -> int: + return 
torch.ops._C_cuda_utils.get_device_attribute(attribute, device) + + +def get_max_shared_memory_per_block_device_attribute(device: int) -> int: + # ruff: noqa: E501 + return torch.ops._C_cuda_utils.get_max_shared_memory_per_block_device_attribute( + device) + + +# custom ar +def init_custom_ar(meta: torch.Tensor, rank_data: torch.Tensor, + handles: List[str], offsets: List[int], rank: int, + full_nvlink: bool) -> int: + return torch.ops._C_custom_ar.init_custom_ar(meta, rank_data, handles, + offsets, rank, full_nvlink) + + +def should_custom_ar(inp: torch.Tensor, max_size: int, world_size: int, + full_nvlink: bool) -> bool: + return torch.ops._C_custom_ar.should_custom_ar(inp, max_size, world_size, + full_nvlink) + + +def all_reduce_reg(fa: int, inp: torch.Tensor, out: torch.Tensor) -> None: + torch.ops._C_custom_ar.all_reduce_reg(fa, inp, out) + +def all_reduce_unreg(fa: int, inp: torch.Tensor, reg_buffer: torch.Tensor, + out: torch.Tensor) -> None: + torch.ops._C_custom_ar.all_reduce_unreg(fa, inp, reg_buffer, out) -#TODO: cuda_utils, custom_ar + +def dispose(fa: int) -> None: + torch.ops._C_custom_ar.dispose(fa) + + +def meta_size() -> int: + return torch.ops._C_custom_ar.meta_size() + + +def register_buffer(fa: int, t: torch.Tensor, handles: List[str], + offsets: List[int]) -> None: + return torch.ops._C_custom_ar.register_buffer(fa, t, handles, offsets) + + +def get_graph_buffer_ipc_meta(fa: int) -> Tuple[List[str], List[int]]: + return torch.ops._C_custom_ar.get_graph_buffer_ipc_meta(fa) + + +def register_graph_buffers(fa: int, handles: List[str], + offsets: List[List[int]]) -> None: + torch.ops._C_custom_ar.register_graph_buffers(fa, handles, offsets) + + +# punica +def dispatch_bgmv( + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + indicies: torch.Tensor, + layer_idx: int, + scale: float, +) -> None: + torch.ops._punica_C.dispatch_bgmv(y, x, w_t_all, indicies, layer_idx, + scale) + + +def dispatch_bgmv_low_level( + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + indicies: torch.Tensor, + layer_idx: int, + scale: float, + h_in: int, + h_out: int, + y_offset: int, +) -> None: + torch.ops._punica_C.dispatch_bgmv_low_level( + y, + x, + w_t_all, + indicies, + layer_idx, + scale, + h_in, + h_out, + y_offset, + ) diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 070c074e511bc..8c64c2bfdeb8f 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -5,7 +5,7 @@ import torch from vllm_flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache -from vllm._C import cache_ops +from vllm import _custom_ops as ops from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionMetadata) @@ -47,11 +47,11 @@ def swap_blocks( ) -> None: src_key_cache = src_kv_cache[0] dst_key_cache = dst_kv_cache[0] - cache_ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst) + ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst) src_value_cache = src_kv_cache[1] dst_value_cache = dst_kv_cache[1] - cache_ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dst) + ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dst) @staticmethod def copy_blocks( @@ -60,7 +60,7 @@ def copy_blocks( ) -> None: key_caches = [kv_cache[0] for kv_cache in kv_caches] value_caches = [kv_cache[1] for kv_cache in kv_caches] - cache_ops.copy_blocks(key_caches, value_caches, src_to_dists) + ops.copy_blocks(key_caches, value_caches, src_to_dists) @dataclass @@ -285,7 +285,7 @@ 
def forward( # Reshape the input keys and values and store them in the cache. # If kv_cache is not provided, the new key and value tensors are # not cached. This happens during the initial memory profiling run. - cache_ops.reshape_and_cache_flash( + ops.reshape_and_cache_flash( key, value, key_cache, diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index e92e6c5e2dc8d..9294068c64d1a 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -247,7 +247,7 @@ def __init__( self.use_naive_attn = True if self.use_naive_attn: - self.attn_func = _naive_attention + self.attn_func = _sdpa_attention logger.debug("Using naive attention in ROCmBackend") def repeat_kv(self, x: torch.Tensor, n_rep: int) -> torch.Tensor: @@ -342,11 +342,18 @@ def forward( # Interleave for MQA workaround. key = self.repeat_kv(key, self.num_queries_per_kv) value = self.repeat_kv(value, self.num_queries_per_kv) + query = query.movedim(0, query.dim() - 2) + key = key.movedim(0, key.dim() - 2) + value = value.movedim(0, value.dim() - 2) + # sdpa math backend attention out = self.attn_func( query, key, value, prefill_meta.seq_lens, + num_tokens, + self.num_heads, + self.head_size, self.scale, ) else: @@ -402,45 +409,34 @@ def forward( return output.view(num_tokens, hidden_size) -def _naive_attention( +def _sdpa_attention( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, seq_lens: List[int], + num_tokens: int, + num_heads: int, + head_size: int, scale: float, ) -> torch.Tensor: - output = torch.empty_like(query) start = 0 - for _, seq_len in enumerate(seq_lens): + output = torch.empty((num_tokens, num_heads, head_size), + dtype=query.dtype, + device=query.device) + + for seq_len in seq_lens: end = start + seq_len - out = _naive_masked_attention( - query[start:end], - key[start:end], - value[start:end], - scale, - ) - # TODO(woosuk): Unnecessary copy. Optimize. 
-        output[start:end].copy_(out)
-        start += seq_len
+        with torch.backends.cuda.sdp_kernel(enable_math=True,
+                                            enable_flash=False,
+                                            enable_mem_efficient=False):
+            sub_out = torch.nn.functional.scaled_dot_product_attention(
+                query[:, start:end, :],
+                key[:, start:end, :],
+                value[:, start:end, :],
+                dropout_p=0.0,
+                is_causal=True,
+                scale=scale).movedim(query.dim() - 2, 0)
+            output[start:end, :, :] = sub_out
+            start = end
 
     return output
-
-
-def _naive_masked_attention(
-    query: torch.Tensor,
-    key: torch.Tensor,
-    value: torch.Tensor,
-    scale: float,
-) -> torch.Tensor:
-    seq_len, head_size, head_dim = query.shape
-    attn_mask = torch.triu(torch.ones(seq_len,
-                                      seq_len,
-                                      dtype=query.dtype,
-                                      device=query.device),
-                           diagonal=1)
-    attn_mask = attn_mask * torch.finfo(query.dtype).min
-    attn_weights = scale * torch.einsum("qhd,khd->hqk", query, key).float()
-    attn_weights = attn_weights + attn_mask.float()
-    attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype)
-    out = torch.einsum("hqk,khd->qhd", attn_weights, value)
-    return out
diff --git a/vllm/config.py b/vllm/config.py
index 4efdb6cab52c4..fa296cd626f17 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -5,7 +5,7 @@
                     Union)
 
 import torch
-from transformers import PretrainedConfig
+from transformers import PretrainedConfig, PreTrainedTokenizerBase
 
 from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
@@ -164,12 +164,8 @@ def _verify_embedding_mode(self) -> None:
     def _parse_quant_hf_config(self):
         quant_cfg = getattr(self.hf_config, "quantization_config", None)
         if quant_cfg is None:
-            # SparseML uses a "compression_config" with a "quantization_config".
-            compression_cfg = getattr(self.hf_config, "compression_config",
-                                      None)
-            if compression_cfg is not None:
-                quant_cfg = compression_cfg.get("quantization_config", None)
-
+            # compressed-tensors uses a "compression_config" key
+            quant_cfg = getattr(self.hf_config, "compression_config", None)
         return quant_cfg
 
     def _verify_quantization(self) -> None:
@@ -1119,6 +1115,16 @@ def get_image_input_enum_type(cls, value: str) -> ImageInputType:
                 f"Expecting to choose from "
                 f"{[x.name for x in cls.ImageInputType]}.") from e
 
+    #TODO(ywang96): make this a cached property once we refactor the
+    # VisionLanguageConfig class.
+    def get_image_token_text(
+            self, tokenizer: PreTrainedTokenizerBase) -> Tuple[str, str]:
+        """Get the image token placeholder text to be inserted into the
+        text prompt and the string representation of the image token id.
+        """
+        image_token_str = tokenizer.decode(self.image_token_id)
+        return image_token_str * self.image_feature_size, image_token_str
+
     def as_cli_args_dict(self) -> Dict[str, Any]:
         """Flatten vision language config to pure args.
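The rewritten `_sdpa_attention` walks the packed token buffer one sequence at a time and pins PyTorch to the math SDPA backend. Below is a minimal runnable sketch of the same loop; the tensor sizes are illustrative rather than vLLM's real configuration, and the default softmax scale stands in for the explicit `scale` argument:

```python
import torch
import torch.nn.functional as F

num_heads, head_size = 4, 8
seq_lens = [3, 5]                      # two sequences packed along dim 0
num_tokens = sum(seq_lens)

# the attention kernels hand us (num_tokens, num_heads, head_size)
query = torch.randn(num_tokens, num_heads, head_size)
key = torch.randn_like(query)
value = torch.randn_like(query)

# SDPA expects (..., num_heads, seq_len, head_size), hence the movedim
query = query.movedim(0, 1)
key = key.movedim(0, 1)
value = value.movedim(0, 1)

output = torch.empty(num_tokens, num_heads, head_size)
start = 0
for seq_len in seq_lens:
    end = start + seq_len
    # each sequence is attended causally in isolation, matching the
    # per-sequence slicing in the patch above
    sub_out = F.scaled_dot_product_attention(
        query[:, start:end, :],
        key[:, start:end, :],
        value[:, start:end, :],
        is_causal=True,
    ).movedim(1, 0)                    # back to (seq_len, num_heads, head_size)
    output[start:end] = sub_out
    start = end

print(output.shape)                    # torch.Size([8, 4, 8])
```

The patch's `sdp_kernel(enable_math=True, ...)` context keeps ROCm off the flash and memory-efficient kernels, which is why a plain `scaled_dot_product_attention` call is an adequate stand-in here.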
diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 0159053b4dc6a..bb37c5f313617 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -423,7 +423,9 @@ def _schedule_running( num_running_seqs = seq_group.get_max_num_running_seqs() budget.subtract_num_seqs(seq_group.request_id, num_running_seqs) - if curr_loras is not None and seq_group.lora_int_id > 0: + + if (curr_loras is not None and seq_group.lora_int_id > 0 + and seq_group.lora_int_id in curr_loras): curr_loras.remove(seq_group.lora_int_id) if running_queue: diff --git a/vllm/distributed/device_communicators/custom_all_reduce.py b/vllm/distributed/device_communicators/custom_all_reduce.py index a3902aecb3793..4a0e19bc0c159 100644 --- a/vllm/distributed/device_communicators/custom_all_reduce.py +++ b/vllm/distributed/device_communicators/custom_all_reduce.py @@ -6,6 +6,7 @@ from torch.distributed import ProcessGroup import vllm.envs as envs +from vllm import _custom_ops as ops from vllm.distributed.device_communicators.custom_all_reduce_utils import ( gpu_p2p_access_check) from vllm.distributed.parallel_state import ( @@ -15,7 +16,11 @@ try: import pynvml - from vllm._C import custom_ar + # Simulate ImportError if custom_ar ops are not supported. + if not ops.is_custom_op_supported("_C_custom_ar::meta_size"): + raise ImportError("custom_ar", __file__) + + custom_ar = True @contextmanager def _nvml(): @@ -27,7 +32,7 @@ def _nvml(): except ImportError: # For AMD GPUs - custom_ar = None + custom_ar = False pynvml = None @contextmanager @@ -97,7 +102,7 @@ def __init__(self, self._IS_CAPTURING = False self.disabled = True - if custom_ar is None: + if not custom_ar: # disable because of missing custom allreduce library # e.g. in a non-cuda environment return @@ -175,7 +180,7 @@ def __init__(self, # meta data composes of two parts: meta data for synchronization # (256 bytes) and a temporary buffer for storing intermediate # allreduce results. - self.meta = torch.zeros(custom_ar.meta_size() + max_size, + self.meta = torch.zeros(ops.meta_size() + max_size, dtype=torch.uint8, device=self.device) # This is a pre-registered IPC buffer. 
In eager mode, input tensors @@ -196,9 +201,8 @@ def __init__(self, self.world_size = world_size handles, offsets = self._get_ipc_meta(self.meta) self.full_nvlink = full_nvlink - self._ptr = custom_ar.init_custom_ar(self.meta, self.rank_data, - handles, offsets, rank, - self.full_nvlink) + self._ptr = ops.init_custom_ar(self.meta, self.rank_data, handles, + offsets, rank, self.full_nvlink) self.register_buffer(self.buffer) @contextmanager @@ -252,31 +256,31 @@ def _gather_ipc_meta(self, shard_data): def register_buffer(self, inp: torch.Tensor): handles, offsets = self._get_ipc_meta(inp) - custom_ar.register_buffer(self._ptr, inp, handles, offsets) + ops.register_buffer(self._ptr, inp, handles, offsets) def register_graph_buffers(self): - handle, offset = custom_ar.get_graph_buffer_ipc_meta(self._ptr) + handle, offset = ops.get_graph_buffer_ipc_meta(self._ptr) handles, offsets = self._gather_ipc_meta((bytes(handle), offset)) logger.info("Registering %d cuda graph addresses", len(offset)) - custom_ar.register_graph_buffers(self._ptr, handles, offsets) + ops.register_graph_buffers(self._ptr, handles, offsets) def should_custom_ar(self, inp: torch.Tensor): - return custom_ar.should_custom_ar(inp, self.max_size, self.world_size, - self.full_nvlink) + return ops.should_custom_ar(inp, self.max_size, self.world_size, + self.full_nvlink) # all reduce, assuming inp tensor is IPC registered with register_buffer, # or, in the context of cuda graphs, register_graph_buffers def all_reduce_reg(self, inp: torch.Tensor, out: torch.Tensor = None): if out is None: out = torch.empty_like(inp) - custom_ar.all_reduce_reg(self._ptr, inp, out) + ops.all_reduce_reg(self._ptr, inp, out) return out # all reduce, assuming inp tensor is NOT IPC registered def all_reduce_unreg(self, inp: torch.Tensor, out: torch.Tensor = None): if out is None: out = torch.empty_like(inp) - custom_ar.all_reduce_unreg(self._ptr, inp, self.buffer, out) + ops.all_reduce_unreg(self._ptr, inp, self.buffer, out) return out def custom_all_reduce(self, input: torch.Tensor) -> Optional[torch.Tensor]: @@ -304,7 +308,7 @@ def custom_all_reduce(self, input: torch.Tensor) -> Optional[torch.Tensor]: def close(self): if not self.disabled and self._ptr: - custom_ar.dispose(self._ptr) + ops.dispose(self._ptr) self._ptr = 0 def __del__(self): diff --git a/vllm/distributed/device_communicators/custom_all_reduce_utils.py b/vllm/distributed/device_communicators/custom_all_reduce_utils.py index 24ef3cb45b19d..4b89a23dfc463 100644 --- a/vllm/distributed/device_communicators/custom_all_reduce_utils.py +++ b/vllm/distributed/device_communicators/custom_all_reduce_utils.py @@ -166,7 +166,7 @@ def gpu_p2p_access_check(i: int, j: int) -> bool: and (not os.path.exists(path))): # only the local master process (with local_rank == 0) can # enter this block to calculate the cache - logger.info("generating GPU P2P access cache for in %s", path) + logger.info("generating GPU P2P access cache in %s", path) cache = {} for _i in range(num_dev): for _j in range(num_dev): diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 95417718b51fe..e7503b9655830 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -183,6 +183,16 @@ async def authentication(request: Request, call_next): served_model_names = [args.model] engine_args = AsyncEngineArgs.from_cli_args(args) + + # Enforce pixel values as image input type for vision language models + # when serving with API server + if 
engine_args.image_input_type is not None and \ + engine_args.image_input_type.upper() != "PIXEL_VALUES": + raise ValueError( + f"Invalid image_input_type: {engine_args.image_input_type}. " + "Only --image-input-type 'pixel_values' is supported for serving " + "vision language models with the vLLM API server.") + engine = AsyncLLMEngine.from_engine_args( engine_args, usage_context=UsageContext.OPENAI_API_SERVER) diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index fa33318786b9a..9424ccc959d11 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -346,6 +346,7 @@ class CompletionRequest(OpenAIBaseModel): le=torch.iinfo(torch.long).max) stop: Optional[Union[str, List[str]]] = Field(default_factory=list) stream: Optional[bool] = False + stream_options: Optional[StreamOptions] = None suffix: Optional[str] = None temperature: Optional[float] = 1.0 top_p: Optional[float] = 1.0 @@ -482,6 +483,14 @@ def check_logprobs(cls, data): " in the interval [0, 5].")) return data + @model_validator(mode="before") + @classmethod + def validate_stream_options(cls, data): + if data.get("stream_options") and not data.get("stream"): + raise ValueError( + "Stream options can only be defined when stream is True.") + return data + class EmbeddingRequest(BaseModel): # Ordered by official OpenAI API documentation diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 883567abf415b..dae60e4ec99f1 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -1,15 +1,16 @@ import codecs import time -from dataclasses import dataclass -from typing import (AsyncGenerator, AsyncIterator, Dict, Iterable, List, - Optional) +from dataclasses import dataclass, field +from typing import (AsyncGenerator, AsyncIterator, Awaitable, Dict, Iterable, + List, Optional) from typing import Sequence as GenericSequence from typing import TypedDict, Union, cast, final from fastapi import Request -from openai.types.chat import ChatCompletionContentPartTextParam +from openai.types.chat import (ChatCompletionContentPartImageParam, + ChatCompletionContentPartTextParam) -from vllm.config import ModelConfig +from vllm.config import ModelConfig, VisionLanguageConfig from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.entrypoints.openai.protocol import ( ChatCompletionContentPartParam, ChatCompletionLogProb, @@ -21,9 +22,13 @@ FunctionCall, ToolCall, UsageInfo) from vllm.entrypoints.openai.serving_engine import (LoRAModulePath, OpenAIServing) +from vllm.inputs import PromptInputs from vllm.logger import init_logger from vllm.model_executor.guided_decoding import ( get_guided_decoding_logits_processor) +from vllm.multimodal.image import ImagePixelData +from vllm.multimodal.utils import (async_get_and_parse_image, + get_full_image_text_prompt) from vllm.outputs import RequestOutput from vllm.sequence import Logprob from vllm.utils import random_uuid @@ -40,6 +45,8 @@ class ConversationMessage(TypedDict): @dataclass(frozen=True) class ChatMessageParseResult: messages: List[ConversationMessage] + image_futures: List[Awaitable[ImagePixelData]] = field( + default_factory=list) class OpenAIServingChat(OpenAIServing): @@ -94,19 +101,76 @@ def _parse_chat_message_content_parts( parts: Iterable[ChatCompletionContentPartParam], ) -> ChatMessageParseResult: texts: List[str] = [] + image_futures: List[Awaitable[ImagePixelData]] = [] - for _, part in enumerate(parts): + vlm_config: 
Optional[VisionLanguageConfig] = getattr(
+            self.engine.engine, "vision_language_config", None)
+        model_config = getattr(self.engine.engine, "model_config", None)
+
+        for part in parts:
             part_type = part["type"]
 
             if part_type == "text":
                 text = cast(ChatCompletionContentPartTextParam, part)["text"]
                 texts.append(text)
+            elif part_type == "image_url":
+                if vlm_config is None:
+                    raise ValueError(
+                        "'image_url' input is not supported as the loaded "
+                        "model is not multimodal.")
+
+                elif len(image_futures) == 0:
+                    assert self.tokenizer is not None
+                    image_url = cast(ChatCompletionContentPartImageParam,
+                                     part)["image_url"]
+
+                    if image_url.get("detail", "auto") != "auto":
+                        logger.warning(
+                            "'image_url.detail' is currently not supported and "
+                            "will be ignored.")
+
+                    image_future = async_get_and_parse_image(image_url["url"])
+                    image_futures.append(image_future)
+
+                else:
+                    raise NotImplementedError(
+                        "Multiple 'image_url' inputs are currently not "
+                        "supported.")
+
             else:
                 raise NotImplementedError(f"Unknown part type: {part_type}")
 
-        messages = [ConversationMessage(role=role, content="\n".join(texts))]
+        text_prompt = "\n".join(texts)
+
+        if vlm_config is not None and len(image_futures):
 
-        return ChatMessageParseResult(messages=messages)
+            (image_token_prompt,
+             image_token_str) = vlm_config.get_image_token_text(self.tokenizer)
+
+            # NOTE: If the image token string (e.g., <image>) is already
+            # present in the text prompt, we assume it follows the same
+            # format required by the engine.
+            if image_token_str in text_prompt:
+                logger.warning(
+                    "Detected image token string in the text prompt. "
+                    "Skipping prompt formatting.")
+                messages = [
+                    ConversationMessage(role=role, content=text_prompt)
+                ]
+
+            else:
+                full_prompt = get_full_image_text_prompt(
+                    image_prompt=image_token_prompt,
+                    text_prompt=text_prompt,
+                    config=model_config)
+                messages = [
+                    ConversationMessage(role=role, content=full_prompt)
+                ]
+        else:
+            messages = [ConversationMessage(role=role, content=text_prompt)]
+
+        return ChatMessageParseResult(messages=messages,
+                                      image_futures=image_futures)
 
     def _parse_chat_message_content(
         self,
@@ -116,10 +180,10 @@ def _parse_chat_message_content(
         content = message.get("content")
 
         if content is None:
-            return ChatMessageParseResult(messages=[])
+            return ChatMessageParseResult(messages=[], image_futures=[])
         if isinstance(content, str):
             messages = [ConversationMessage(role=role, content=content)]
-            return ChatMessageParseResult(messages=messages)
+            return ChatMessageParseResult(messages=messages, image_futures=[])
 
         return self._parse_chat_message_content_parts(role, content)
 
@@ -144,11 +208,13 @@ async def create_chat_completion(
         try:
             conversation: List[ConversationMessage] = []
+            image_futures: List[Awaitable[ImagePixelData]] = []
 
             for msg in request.messages:
-                parsed_msg = self._parse_chat_message_content(msg)
+                chat_parsed_result = self._parse_chat_message_content(msg)
 
-                conversation.extend(parsed_msg.messages)
+                conversation.extend(chat_parsed_result.messages)
+                image_futures.extend(chat_parsed_result.image_futures)
 
             prompt = self.tokenizer.apply_chat_template(
                 conversation=conversation,
@@ -159,6 +225,17 @@ async def create_chat_completion(
             logger.error("Error in applying chat template from request: %s", e)
             return self.create_error_response(str(e))
 
+        # Fetch image data
+        image_data: Optional[ImagePixelData] = None
+        try:
+            if len(image_futures):
+                # since we currently support only a single image
+                assert len(image_futures) == 1
+                image_data = await image_futures[0]
+        except Exception as e:
+
logger.error("Error in loading image data: %s", e) + return self.create_error_response(str(e)) + request_id = f"cmpl-{random_uuid()}" try: # Tokenize/detokenize depending on prompt format (string/token list) @@ -183,11 +260,15 @@ async def create_chat_completion( except ValueError as e: return self.create_error_response(str(e)) + inputs: PromptInputs = { + "prompt": prompt_text, + "prompt_token_ids": prompt_ids, + } + if image_data is not None: + inputs["multi_modal_data"] = image_data + result_generator = self.engine.generate( - { - "prompt": prompt_text, - "prompt_token_ids": prompt_ids - }, + inputs, sampling_params, request_id, lora_request, @@ -360,25 +441,24 @@ async def chat_completion_stream_generator( yield f"data: {data}\n\n" finish_reason_sent[i] = True - if (request.stream_options - and request.stream_options.include_usage): - final_usage = UsageInfo( - prompt_tokens=prompt_tokens, - completion_tokens=previous_num_tokens[i], - total_tokens=prompt_tokens + - previous_num_tokens[i], - ) + if (request.stream_options + and request.stream_options.include_usage): + final_usage = UsageInfo( + prompt_tokens=prompt_tokens, + completion_tokens=previous_num_tokens[i], + total_tokens=prompt_tokens + previous_num_tokens[i], + ) - final_usage_chunk = ChatCompletionStreamResponse( - id=request_id, - object=chunk_object_type, - created=created_time, - choices=[], - model=model_name, - usage=final_usage) - final_usage_data = (final_usage_chunk.model_dump_json( - exclude_unset=True, exclude_none=True)) - yield f"data: {final_usage_data}\n\n" + final_usage_chunk = ChatCompletionStreamResponse( + id=request_id, + object=chunk_object_type, + created=created_time, + choices=[], + model=model_name, + usage=final_usage) + final_usage_data = (final_usage_chunk.model_dump_json( + exclude_unset=True, exclude_none=True)) + yield f"data: {final_usage_data}\n\n" except ValueError as e: # TODO: Use a vllm-specific Validation Error diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index 572878b5527dc..c3c40f2b97d14 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -264,7 +264,8 @@ async def completion_stream_generator( ) else: final_usage = None - response_json = CompletionStreamResponse( + + chunk = CompletionStreamResponse( id=request_id, created=created_time, model=model_name, @@ -276,10 +277,27 @@ async def completion_stream_generator( finish_reason=finish_reason, stop_reason=stop_reason, ) - ], - usage=final_usage, - ).model_dump_json(exclude_unset=True) + ]) + if (request.stream_options + and request.stream_options.include_usage): + chunk.usage = None + + response_json = chunk.model_dump_json(exclude_unset=True) yield f"data: {response_json}\n\n" + + if (request.stream_options + and request.stream_options.include_usage): + final_usage_chunk = CompletionStreamResponse( + id=request_id, + created=created_time, + model=model_name, + choices=[], + usage=final_usage, + ) + final_usage_data = (final_usage_chunk.model_dump_json( + exclude_unset=True, exclude_none=True)) + yield f"data: {final_usage_data}\n\n" + except ValueError as e: # TODO: Use a vllm-specific Validation Error data = self.create_streaming_error_response(str(e)) diff --git a/vllm/envs.py b/vllm/envs.py index 7d5c7371b7741..b140aa6d658e6 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -29,6 +29,7 @@ VLLM_CPU_KVCACHE_SPACE: int = 0 VLLM_USE_RAY_COMPILED_DAG: bool = False VLLM_WORKER_MULTIPROC_METHOD: str = "spawn" + 
VLLM_IMAGE_FETCH_TIMEOUT: int = 5 VLLM_TARGET_DEVICE: str = "cuda" MAX_JOBS: Optional[str] = None NVCC_THREADS: Optional[str] = None @@ -216,6 +217,11 @@ # Both spawn and fork work "VLLM_WORKER_MULTIPROC_METHOD": lambda: os.getenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn"), + + # Timeout for fetching images when serving multimodal models + # Default is 5 seconds + "VLLM_IMAGE_FETCH_TIMEOUT": + lambda: int(os.getenv("VLLM_IMAGE_FETCH_TIMEOUT", "5")), } # end-env-vars-definition diff --git a/vllm/lora/punica.py b/vllm/lora/punica.py index c87bed54726fc..7ecaa450f1758 100644 --- a/vllm/lora/punica.py +++ b/vllm/lora/punica.py @@ -4,16 +4,21 @@ import torch +from vllm import _custom_ops as ops + + +def _check_punica_support(): + if ops.is_custom_op_supported("_punica_C::dispatch_bgmv"): + return -def _raise_import_error(e): if torch.cuda.get_device_capability() < (8, 0): raise ImportError( - "punica LoRA kernels require compute capability >= 8.0") from e + "punica LoRA kernels require compute capability >= 8.0") else: raise ImportError( "punica LoRA kernels could not be imported. If you built vLLM " "from source, make sure VLLM_INSTALL_PUNICA_KERNELS=1 env var " - "was set.") from e + "was set.") def bgmv( @@ -41,12 +46,9 @@ def bgmv( layer_idx: Layer index of the weight matrices. scale: Scaling factor. """ - try: - import vllm._punica_C as punica_kernels - except ImportError as e: - _raise_import_error(e) + _check_punica_support() - punica_kernels.dispatch_bgmv(y, x, w_t_all, indicies, layer_idx, scale) + ops.dispatch_bgmv(y, x, w_t_all, indicies, layer_idx, scale) def dispatch_bgmv_low_level(y: torch.Tensor, x: torch.Tensor, @@ -75,11 +77,9 @@ def dispatch_bgmv_low_level(y: torch.Tensor, x: torch.Tensor, y_offset: Offset to apply to the starting column of y. y_slice_size: Size of the y column slice. """ - try: - import vllm._punica_C as punica_kernels - except ImportError as e: - _raise_import_error(e) - punica_kernels.dispatch_bgmv_low_level( + _check_punica_support() + + ops.dispatch_bgmv_low_level( y, x, w_t_all, @@ -122,10 +122,7 @@ def add_lora(y: torch.Tensor, scale: Scaling factor. buffer: Optional. Shape: `[B, R]`. Temporary buffer. """ - try: - import vllm._punica_C as punica_kernels - except ImportError as e: - _raise_import_error(e) + _check_punica_support() r = wb_t_all.size(-1) if buffer is None: @@ -135,9 +132,8 @@ def add_lora(y: torch.Tensor, buffer = torch.zeros((x.size(0), r), dtype=torch.float32, device=x.device) - punica_kernels.dispatch_bgmv(buffer, x, wa_t_all, indicies, layer_idx, 1.0) - punica_kernels.dispatch_bgmv(y, buffer, wb_t_all, indicies, layer_idx, - scale) + ops.dispatch_bgmv(buffer, x, wa_t_all, indicies, layer_idx, 1.0) + ops.dispatch_bgmv(y, buffer, wb_t_all, indicies, layer_idx, scale) def add_lora_slice(y: torch.Tensor, @@ -176,10 +172,7 @@ def add_lora_slice(y: torch.Tensor, y_offset: Offset to apply to the starting column of y. y_slice_size: Size of the y column slice. 
""" - try: - import vllm._punica_C as punica_kernels - except ImportError as e: - _raise_import_error(e) + _check_punica_support() r = wb_t_all.size(-1) if buffer is None: @@ -189,7 +182,7 @@ def add_lora_slice(y: torch.Tensor, buffer = torch.zeros((x.size(0), r), dtype=torch.float32, device=x.device) - punica_kernels.dispatch_bgmv_low_level( + ops.dispatch_bgmv_low_level( buffer, x, wa_t_all, @@ -200,7 +193,7 @@ def add_lora_slice(y: torch.Tensor, buffer.size(1), 0, ) - punica_kernels.dispatch_bgmv_low_level( + ops.dispatch_bgmv_low_level( y, buffer, wb_t_all, diff --git a/vllm/lora/utils.py b/vllm/lora/utils.py index b0198a50b1c52..4a86c16cf64db 100644 --- a/vllm/lora/utils.py +++ b/vllm/lora/utils.py @@ -94,13 +94,12 @@ def parse_fine_tuned_lora_name(name: str) -> Tuple[str, bool]: is_lora_a whether the tensor is lora_a or lora_b. """ parts = name.split(".") - assert parts[0] == "base_model" - assert parts[1] == "model" - if parts[-1] == "weight": - assert parts[-2] == "lora_A" or parts[-2] == "lora_B" - return ".".join(parts[2:-2]), parts[-2] == "lora_A" - if parts[-1] == "lora_embedding_A" or parts[-1] == "lora_embedding_B": - return ".".join(parts[2:-1]), parts[-1] == "lora_embedding_A" + if len(parts) >= 2 and parts[0] == "base_model" and parts[1] == "model": + if parts[-1] == "weight": + if parts[-2] == "lora_A" or parts[-2] == "lora_B": + return ".".join(parts[2:-2]), parts[-2] == "lora_A" + elif parts[-1] == "lora_embedding_A" or parts[-1] == "lora_embedding_B": + return ".".join(parts[2:-1]), parts[-1] == "lora_embedding_A" - raise ValueError(f"{name} is unsupported format") + raise ValueError(f"{name} is unsupported LoRA weight") diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 1c6947137a1c9..4d0160ff296a0 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -8,7 +8,6 @@ import triton import triton.language as tl -import vllm._moe_C as moe_kernels from vllm import _custom_ops as ops from vllm.logger import init_logger @@ -355,7 +354,7 @@ def fused_topk( topk, dtype=torch.int32, device=hidden_states.device) - moe_kernels.topk_softmax( + ops.topk_softmax( topk_weights, topk_ids, token_expert_indicies, diff --git a/vllm/model_executor/layers/quantization/__init__.py b/vllm/model_executor/layers/quantization/__init__.py index 0bc42beb66257..40b0df75a69a6 100644 --- a/vllm/model_executor/layers/quantization/__init__.py +++ b/vllm/model_executor/layers/quantization/__init__.py @@ -31,7 +31,7 @@ "gptq_marlin": GPTQMarlinConfig, "gptq": GPTQConfig, "squeezellm": SqueezeLLMConfig, - "sparseml": CompressedTensorsConfig, + "compressed-tensors": CompressedTensorsConfig, "bitsandbytes": BitsAndBytesConfig, } diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 136a64623d7fb..0cf2bd927a800 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -20,16 +20,16 @@ def cutlass_fp8_supported() -> bool: capability = torch.cuda.get_device_capability() capability = capability[0] * 10 + capability[1] - version = torch.version.cuda - version = version[0] * 10 + version[1] + major, minor = torch.version.cuda.split(".") + version = int(major) * 10 + int(minor) # CUTLASS FP8 kernels need at least # CUDA 12.0 on SM90 systems (Hopper) # CUDA 12.4 on SM89 systems (Lovelace) gpu_is_supported = False - if capability >= 900: + if capability >= 90: 
gpu_is_supported = version > 120 - elif capability >= 890: + elif capability >= 89: gpu_is_supported = version > 124 return gpu_is_supported @@ -103,7 +103,7 @@ class Fp8LinearMethod(LinearMethodBase): 1. Only support per-tensor quantization due to torch._scaled_mm support. 2. Only support float8_e4m3fn data type due to the limitation of torch._scaled_mm (https://github.com/pytorch/pytorch/blob/2e48b39603411a41c5025efbe52f89560b827825/aten/src/ATen/native/cuda/Blas.cpp#L854-L856) - + Args: quant_config: The quantization config. """ @@ -171,10 +171,10 @@ def create_weights( output_partition_sizes=output_partition_sizes, **extra_weight_attrs) - # ACTIVATION SCALE + # INPUT ACTIVATION SCALE if self.quant_config.activation_scheme == "static": self._create_scale_param( - scale_name="act_scale", + scale_name="input_scale", layer=layer, output_partition_sizes=output_partition_sizes, **extra_weight_attrs) @@ -207,7 +207,7 @@ def process_weights_after_loading(self, layer: Module) -> None: layer.weight = Parameter(qweight.t(), requires_grad=False) layer.weight_scale = Parameter(weight_scale, requires_grad=False) layer.logical_widths = None - layer.act_scale = None + layer.input_scale = None return # If checkpoint is fp8, requantize the separately quantized logical @@ -232,18 +232,18 @@ def process_weights_after_loading(self, layer: Module) -> None: weight = layer.weight layer.weight = Parameter(weight.t(), requires_grad=False) - # ACT_SCALE + # INPUT ACTIVATION SCALE # Dynamic: set to None (required input to ops.scaled_fp8_quant). - # Static: set to max of the act_scales (since they are equal). + # Static: set to max of the input_scales (since they are equal). if self.quant_config.activation_scheme == "dynamic": - layer.act_scale = None + layer.input_scale = None elif self.quant_config.activation_scheme == "static": - if not all_close_1d(layer.act_scale): + if not all_close_1d(layer.input_scale): raise ValueError( - "All the act_scales for the logical weights of a layer " - f"must be equal. But got {layer.act_scale}") - layer.act_scale = Parameter(layer.act_scale.max(), - requires_grad=False) + "All the input_scales for the logical weights of a " + f"layer must be equal. But got {layer.input_scale}") + layer.input_scale = Parameter(layer.input_scale.max(), + requires_grad=False) else: raise ValueError( f"Unknown scheme {self.quant_config.activation_scheme}") @@ -254,11 +254,11 @@ def apply(self, bias: Optional[torch.Tensor] = None) -> torch.Tensor: # ops.scaled_fp8_quant supports both dynamic and static quant. - # If dynamic, layer.act_scale is None and x_scale computed from x. - # If static, layer.act_scale is scalar and x_scale set to act_scale. + # If dynamic, layer.input_scale is None and x_scale computed from x. + # If static, layer.input_scale is scalar and x_scale is input_scale. if bias is None and self.cutlass_fp8_supported: - qinput, x_scale = ops.scaled_fp8_quant(x, layer.act_scale) + qinput, x_scale = ops.scaled_fp8_quant(x, layer.input_scale) # Fused GEMM_DQ output = ops.cutlass_scaled_mm_dq( @@ -271,7 +271,7 @@ def apply(self, else: qinput, x_scale = ops.scaled_fp8_quant(x, - layer.act_scale, + layer.input_scale, batch_dim_padding=17) # Fused GEMM_DQ -- note we padded the input above because @@ -298,8 +298,8 @@ def __init__(self, quant_config: Fp8Config): self.quant_config = quant_config def create_weights(self, layer: torch.nn.Module): - """Create "weight" (aka kv_scale) for an attention layer. - + """Create "weight" (aka kv_scale) for an attention layer. 
+
+        Args:
             layer: The layer that is using the QuantizeMethodBase factory.
         """
diff --git a/vllm/model_executor/model_loader/weight_utils.py b/vllm/model_executor/model_loader/weight_utils.py
index 6174f0a974712..827591b227a2b 100644
--- a/vllm/model_executor/model_loader/weight_utils.py
+++ b/vllm/model_executor/model_loader/weight_utils.py
@@ -122,12 +122,9 @@ def get_quant_config(model_config: ModelConfig,
     hf_quant_config = getattr(model_config.hf_config, "quantization_config",
                               None)
     if hf_quant_config is None:
-        compression_config = getattr(model_config.hf_config,
-                                     "compression_config", None)
-        if compression_config is not None:
-            hf_quant_config = compression_config.get("quantization_config",
-                                                     None)
-
+        # compressed-tensors uses a "compression_config" key
+        hf_quant_config = getattr(model_config.hf_config, "compression_config",
+                                  None)
     if hf_quant_config is not None:
         return quant_cls.from_config(hf_quant_config)
     # In case of bitsandbytes/QLoRA, get quant config from the adapter model.
diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py
index a92abe6b5b8dc..4446914c67c8e 100755
--- a/vllm/model_executor/models/__init__.py
+++ b/vllm/model_executor/models/__init__.py
@@ -33,6 +33,8 @@
     "LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
     "LlavaForConditionalGeneration":
     ("llava", "LlavaForConditionalGeneration"),
+    "LlavaNextForConditionalGeneration":
+    ("llava_next", "LlavaNextForConditionalGeneration"),
     # For decapoda-research/llama-*
     "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
     "MistralForCausalLM": ("llama", "LlamaForCausalLM"),
diff --git a/vllm/model_executor/models/dbrx.py b/vllm/model_executor/models/dbrx.py
index 8ff19a2015e0f..59af42445f323 100644
--- a/vllm/model_executor/models/dbrx.py
+++ b/vllm/model_executor/models/dbrx.py
@@ -247,11 +247,12 @@ class DbrxFusedNormAttention(nn.Module):
     def __init__(
         self,
         config: DbrxConfig,
+        cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
         self.d_model = config.d_model
-        self.attn = DbrxAttention(config, quant_config)
+        self.attn = DbrxAttention(config, cache_config, quant_config)
         self.norm_1 = nn.LayerNorm(self.d_model)
         self.norm_2 = nn.LayerNorm(self.d_model)
diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py
index 3332bcc578460..67b32a08833b6 100644
--- a/vllm/model_executor/models/llava.py
+++ b/vllm/model_executor/models/llava.py
@@ -1,7 +1,7 @@
 from typing import Iterable, List, Literal, Optional, Tuple, TypedDict, Union
 
 import torch
-from torch import nn
+import torch.nn as nn
 # TODO(xwjiang): We should port CLIPVisionModel's code over to not depend on
 # transformers' impl.
 from transformers import CLIPVisionModel, LlavaConfig
@@ -51,10 +51,10 @@ def forward(self, image_features: torch.Tensor) -> torch.Tensor:
         return hidden_states
 
 
-def _merge_vision_embeddings(input_ids: torch.Tensor,
-                             inputs_embeds: torch.Tensor,
-                             vision_embeddings: torch.Tensor,
-                             image_token_id: int) -> torch.Tensor:
+def merge_vision_embeddings(input_ids: torch.Tensor,
+                            inputs_embeds: torch.Tensor,
+                            vision_embeddings: torch.Tensor,
+                            image_token_id: int) -> torch.Tensor:
     """In place merges in vision_embeddings with inputs_embeds."""
     mask = (input_ids == image_token_id)
 
@@ -151,7 +151,8 @@ def _parse_and_validate_image_input(
             return None
 
         if not isinstance(pixel_values, torch.Tensor):
-            raise ValueError("Incorrect type of pixel values")
+            raise ValueError("Incorrect type of pixel values. 
" + f"Got type: {type(pixel_values)}") return LlavaImagePixelInputs( type="pixel_values", @@ -166,7 +167,8 @@ def _parse_and_validate_image_input( return None if not isinstance(image_features, torch.Tensor): - raise ValueError("Incorrect type of image features") + raise ValueError("Incorrect type of image features. " + f"Got type: {type(image_features)}") return LlavaImageFeatureInputs( type="image_features", @@ -268,7 +270,7 @@ def forward( vision_embeddings = self._process_image_input(image_input) inputs_embeds = self.language_model.get_input_embeddings(input_ids) - inputs_embeds = _merge_vision_embeddings( + inputs_embeds = merge_vision_embeddings( input_ids, inputs_embeds, vision_embeddings, self.vision_language_config.image_token_id) diff --git a/vllm/model_executor/models/llava_next.py b/vllm/model_executor/models/llava_next.py new file mode 100644 index 0000000000000..57cbd1e4a6018 --- /dev/null +++ b/vllm/model_executor/models/llava_next.py @@ -0,0 +1,469 @@ +from typing import (Dict, Iterable, List, Literal, Optional, Tuple, TypedDict, + Union) + +import torch +import torch.nn as nn +from PIL import Image +# TODO(xwjiang): We should port CLIPVisionModel's code over to not depend on +# transformers' impl. +from transformers import CLIPVisionModel, LlavaNextConfig +from transformers.models.llava_next.modeling_llava_next import ( + get_anyres_image_grid_shape, unpad_image) +from typing_extensions import NotRequired + +from vllm.attention import AttentionMetadata +from vllm.config import CacheConfig, ModelConfig, VisionLanguageConfig +from vllm.logger import init_logger +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.models.llama import LlamaModel +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalData +from vllm.multimodal.image import ImagePixelData, get_dummy_image_data +from vllm.sequence import SamplerOutput, SequenceData + +from .llava import LlavaMultiModalProjector, merge_vision_embeddings +from .vlm_base import VisionLanguageModelBase + +logger = init_logger(__name__) + +_KEYS_TO_MODIFY_MAPPING = { + "language_model.lm_head": "lm_head", + "language_model.model": "language_model", +} + + +class LlavaNextImagePixelInputs(TypedDict): + type: Literal["pixel_values"] + data: torch.Tensor + """Shape: (batch_size, 1 + num_patches, num_channels, height, width)""" + + image_sizes: NotRequired[torch.Tensor] + """Shape: (batch_size, 2)""" + + +class LlavaNextImageFeatureInputs(TypedDict): + type: Literal["image_features"] + data: torch.Tensor + """Shape: (batch_size, 1 + num_patches, image_feature_size, hidden_size)""" + + image_sizes: NotRequired[torch.Tensor] + """Shape: (batch_size, 2)""" + + +LlavaNextImageInputs = Union[LlavaNextImagePixelInputs, + LlavaNextImageFeatureInputs] + + +def _get_dummy_image_data( + seq_len: int, + model_config: ModelConfig, + vlm_config: VisionLanguageConfig, +) -> Tuple[SequenceData, MultiModalData]: + seq_data, fake_mm_data = get_dummy_image_data(seq_len, model_config, + vlm_config) + + config_input_type = vlm_config.image_input_type + ImageInputType = VisionLanguageConfig.ImageInputType + + if config_input_type == 
ImageInputType.PIXEL_VALUES: + _, c, h, w = vlm_config.image_input_shape + mode = {1: "L", 3: "RGB"}[c] + fake_mm_data = ImagePixelData(Image.new(mode, (w, h), color=0)) + + return seq_data, fake_mm_data + + +def _image_pixel_processor( + data: ImagePixelData, + model_config: ModelConfig, + vlm_config: VisionLanguageConfig, +) -> Dict[str, torch.Tensor]: + image = data.image + + if isinstance(image, torch.Tensor): + pixel_values = image.to(model_config.dtype) + batch_size, _, _, h, w = pixel_values.shape + image_sizes = torch.tensor([(w, h) for _ in range(batch_size)]) + + return {"pixel_values": pixel_values, "image_sizes": image_sizes} + + # Temporary patch before dynamic number of image tokens is supported + _, _, h, w = vlm_config.image_input_shape + if (w, h) != (image.width, image.height): + logger.warning( + "Dynamic image shape is currently not supported. " + "Resizing input image to (%d, %d).", w, h) + + data.image = image.resize((w, h)) + + return MULTIMODAL_REGISTRY._get_plugin_for_data_type(ImagePixelData) \ + ._default_input_processor(data, model_config, vlm_config) + + +@MULTIMODAL_REGISTRY.register_image_pixel_input(_image_pixel_processor) +@MULTIMODAL_REGISTRY.register_dummy_data(_get_dummy_image_data) +class LlavaNextForConditionalGeneration(VisionLanguageModelBase): + """ + Args to `forward()`: + input_ids: Flattened (concatenated) input_ids corresponding to a + batch. + pixel_values: For PIXEL_VALUES, expects a batch with shape + [1, num_patches, 3, 336, 336]. + image_features: For IMAGE_FEATURES, expects a batch with shape + [1, num_patches, 1176, 1024]. + """ + + def __init__(self, + config: LlavaNextConfig, + vision_language_config: VisionLanguageConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None) -> None: + super().__init__(vision_language_config) + + # Update the type annotation from that of its superclass + self.config = config + + if self.vision_language_config.image_input_type == ( + VisionLanguageConfig.ImageInputType.PIXEL_VALUES): + self.vision_tower = CLIPVisionModel(config.vision_config) + else: + raise TypeError("Image features are not supported by LLaVA-NeXT") + + self.multi_modal_projector = LlavaMultiModalProjector( + vision_hidden_size=config.vision_config.hidden_size, + text_hidden_size=config.text_config.hidden_size, + projector_hidden_act=config.projector_hidden_act) + + self.quant_config = quant_config + self.language_model = LlamaModel(config.text_config, cache_config, + quant_config) + self.unpadded_vocab_size = config.text_config.vocab_size + self.lm_head = ParallelLMHead( + self.unpadded_vocab_size, + config.text_config.hidden_size, + org_num_embeddings=self.language_model.org_vocab_size) + logit_scale = getattr(config, "logit_scale", 1.0) + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + config.vocab_size, logit_scale) + self.sampler = Sampler() + + self.image_newline = nn.Parameter( + torch.empty(config.text_config.hidden_size)) + + def _validate_image_pixels(self, data: torch.Tensor) -> torch.Tensor: + _, num_channels, _, _ = self.vision_language_config.image_input_shape + + # Note that this is different from that of vLLM vision_language_config + # since the image is resized by the HuggingFace preprocessor + height = width = self.config.vision_config.image_size + + if list(data.shape[2:]) != [num_channels, height, width]: + raise ValueError( + f"The expected image tensor shape is batch dimension plus " + f"num_patches plus {[num_channels, height, width]}. 
" + f"You supplied {data.shape}. " + f"If you are using vLLM's entrypoint, make sure your " + f"supplied image input is consistent with " + f"image_input_shape in engine args.") + + return data + + def _validate_image_sizes(self, data: torch.Tensor) -> torch.Tensor: + if list(data.shape[1:]) != [2]: + raise ValueError( + f"The expected image sizes shape is batch dimension plus " + f"{[2]}. You supplied {data.shape}.") + + return data + + def _parse_and_validate_image_input( + self, **kwargs: object) -> Optional[LlavaNextImageInputs]: + pixel_values = kwargs.pop("pixel_values", None) + image_sizes = kwargs.pop("image_sizes", None) + image_features = kwargs.pop("image_features", None) + + expected_input_type = self.vision_language_config.image_input_type + ImageInputType = VisionLanguageConfig.ImageInputType + + if expected_input_type == ImageInputType.PIXEL_VALUES: + if image_features is not None: + raise ValueError( + "Expected pixel values but got image features") + if pixel_values is None: + return None + + if not isinstance(pixel_values, torch.Tensor): + raise ValueError("Incorrect type of pixel values. " + f"Got type: {type(pixel_values)}") + + if not isinstance(image_sizes, torch.Tensor): + raise ValueError("Incorrect type of image sizes. " + f"Got type: {type(image_sizes)}") + + return LlavaNextImagePixelInputs( + type="pixel_values", + data=self._validate_image_pixels(pixel_values), + image_sizes=self._validate_image_sizes(image_sizes), + ) + + assert expected_input_type != ImageInputType.IMAGE_FEATURES, ( + "Failed to validate this at initialization time") + + return None + + def _select_image_features(self, image_features: torch.Tensor, *, + strategy: str) -> torch.Tensor: + # Copied from https://github.com/huggingface/transformers/blob/39c3c0a72af6fbda5614dde02ff236069bb79827/src/transformers/models/llava/modeling_llava.py#L421 # noqa + if strategy == "default": + return image_features[:, 1:] + elif strategy == "full": + return image_features + + raise ValueError(f"Unexpected select feature strategy: {strategy}") + + def _image_pixels_to_features(self, vision_tower: CLIPVisionModel, + pixel_values: torch.Tensor) -> torch.Tensor: + # TODO(xwjiang): Maybe port minimal CLIPVisionModel over. 
+ image_outputs = vision_tower(pixel_values.to(vision_tower.device), + output_hidden_states=True) + + image_features = image_outputs.hidden_states[ + self.config.vision_feature_layer] + + return self._select_image_features( + image_features, + strategy=self.config.vision_feature_select_strategy, + ) + + def _merge_image_patch_embeddings(self, image_size: torch.Tensor, + patch_embeddings: torch.Tensor, *, + strategy: str) -> torch.Tensor: + # Based on: https://github.com/haotian-liu/LLaVA/blob/main/llava/model/llava_arch.py + if strategy == "flat": + return patch_embeddings.flatten(0, 1) + + if strategy.startswith("spatial"): + orig_width, orig_height = image_size + height = width = self.config.vision_config.image_size \ + // self.config.vision_config.patch_size + + base_patch_embeds = patch_embeddings[0] + if height * width != base_patch_embeds.shape[0]: + raise ValueError( + "The number of patches is not consistent with the " + "image size.") + + if patch_embeddings.shape[0] > 1: + other_patch_embeds = patch_embeddings[1:] + + # image_aspect_ratio == "anyres" + num_patch_width, num_patch_height = get_anyres_image_grid_shape( + (orig_width, orig_height), + self.config.image_grid_pinpoints, + self.config.vision_config.image_size, + ) + other_patch_embeds = other_patch_embeds \ + .view(num_patch_width, num_patch_height, height, width, -1) + + if "unpad" in strategy: + other_patch_embeds = other_patch_embeds \ + .permute(4, 0, 2, 1, 3).contiguous() \ + .flatten(1, 2).flatten(2, 3) + other_patch_embeds = unpad_image(other_patch_embeds, + image_size) + other_patch_embeds = torch.cat(( + other_patch_embeds, + self.image_newline[:, None, None] \ + .expand(*other_patch_embeds.shape[:-1], 1) \ + .to(other_patch_embeds.device), + ), dim=-1) + other_patch_embeds = other_patch_embeds \ + .flatten(1, 2).transpose(0, 1) + else: + other_patch_embeds = other_patch_embeds \ + .permute(0, 2, 1, 3, 4).contiguous() \ + .flatten(0, 3) + + merged_patch_embeddings = torch.cat( + (base_patch_embeds, other_patch_embeds), dim=0) + else: + if "unpad" in strategy: + merged_patch_embeddings = torch.cat( + (base_patch_embeds, + self.image_newline[None] \ + .to(base_patch_embeds.device) + ), dim=0) + else: + merged_patch_embeddings = base_patch_embeds + + return merged_patch_embeddings + + raise ValueError(f"Unexpected patch merge strategy: {strategy}") + + def _process_image_pixels( + self, inputs: LlavaNextImagePixelInputs) -> torch.Tensor: + assert self.vision_tower is not None + + pixel_values = inputs["data"] + + b, num_patches, c, h, w = pixel_values.shape + stacked_pixel_values = pixel_values.view(b * num_patches, c, h, w) + + stacked_image_features = self._image_pixels_to_features( + self.vision_tower, stacked_pixel_values) + + return stacked_image_features.view(b, num_patches, + *stacked_image_features.shape[-2:]) + + def _process_image_input( + self, image_input: LlavaNextImageInputs) -> torch.Tensor: + if image_input["type"] == "pixel_values": + assert self.vision_tower is not None + image_features = self._process_image_pixels(image_input) + else: + image_features = image_input["data"] + + patch_embeddings = self.multi_modal_projector(image_features) + + image_sizes = image_input.get("image_sizes") + if image_sizes is None: + batch_size = image_input["data"].shape[0] + vision_config = self.config.vision_config + default_width = default_height = vision_config.image_size + image_sizes = torch.as_tensor([[default_width, default_height] + for _ in range(batch_size)]) + + merged_patch_embeddings = [ + 
self._merge_image_patch_embeddings(image_sizes[i],
+                                               patch_features,
+                                               strategy="spatial_unpad")
+            for i, patch_features in enumerate(patch_embeddings)
+        ]
+
+        return torch.stack(merged_patch_embeddings, dim=0)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        kv_caches: List[torch.Tensor],
+        attn_metadata: AttentionMetadata,
+        **kwargs: object,
+    ) -> SamplerOutput:
+        """Run forward pass for LLaVA-NeXT.
+
+        One key thing to understand is the `input_ids` already accounts for the
+        positions of the to-be-inserted image embeddings.
+        Concretely, consider a text prompt:
+        "<image>\nUSER: What's the content of the image?\nASSISTANT:".
+        Tokenizer outputs:
+        [1, 32000, 29871, 13, 11889, 29901, 1724, 29915, 29879, 278,
+        2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566, 29901].
+        The to-be-inserted image has a size of 576 (24 * 24) along the context
+        length dimension.
+        `input_ids` is thus [1, 32000, ..., 32000, 29871, 13, 11889, 29901,
+        1724, 29915, 29879, 278, 2793, 310, 278, 1967, 29973, 13, 22933,
+        9047, 13566, 29901].
+        There will be 576 `32000` in the `input_ids`.
+        (32000 is the token id for `<image>`.)
+
+        This way, the `positions` and `attn_metadata` are consistent
+        with the `input_ids`.
+
+        The model takes two types of image inputs:
+        PIXEL_VALUES and IMAGE_FEATURES.
+        The following shows how each maps to huggingface implementation.
+        PIXEL_VALUES:
+        - https://github.com/huggingface/transformers/blob/07bdbeb/src/transformers/models/llava/modeling_llava.py#L353
+        IMAGE_FEATURES:
+        - https://github.com/huggingface/transformers/blob/07bdbeb/src/transformers/models/llava/modeling_llava.py#L430
+        before going through the multi modal projector.
+
+        Args:
+            input_ids: Flattened (concatenated) input_ids corresponding to a
+                batch.
+            pixel_values: For PIXEL_VALUES, expects a batch with shape
+                [1, num_patches, 3, 336, 336].
+            image_features: For IMAGE_FEATURES, expects a batch with shape
+                [1, num_patches, 1176, 1024].
+        """
+        image_input = self._parse_and_validate_image_input(**kwargs)
+
+        if image_input is not None:
+            vision_embeddings = self._process_image_input(image_input)
+            inputs_embeds = self.language_model.get_input_embeddings(input_ids)
+
+            inputs_embeds = merge_vision_embeddings(
+                input_ids, inputs_embeds, vision_embeddings,
+                self.vision_language_config.image_token_id)
+
+            input_ids = None
+        else:
+            inputs_embeds = None
+
+        hidden_states = self.language_model(input_ids,
+                                            positions,
+                                            kv_caches,
+                                            attn_metadata,
+                                            inputs_embeds=inputs_embeds)
+
+        return hidden_states
+
+    def compute_logits(self, hidden_states: torch.Tensor,
+                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+        logits = self.logits_processor(self.lm_head.weight, hidden_states,
+                                       sampling_metadata)
+        return logits
+
+    def sample(
+        self,
+        logits: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[SamplerOutput]:
+        next_tokens = self.sampler(logits, sampling_metadata)
+        return next_tokens
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        # only doing this for language model part for now.
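The `forward` pass above hands the projected features to `merge_vision_embeddings`, imported from `llava.py` earlier in this patch, which overwrites the embedding rows at image-token positions in place. A minimal sketch of that merge; the sizes and the `32000` token id are illustrative only:

```python
import torch

def merge_vision_embeddings(input_ids: torch.Tensor,
                            inputs_embeds: torch.Tensor,
                            vision_embeddings: torch.Tensor,
                            image_token_id: int) -> torch.Tensor:
    # boolean mask over token positions holding the image placeholder
    mask = input_ids == image_token_id
    # one vision-embedding row per placeholder occurrence, written in place
    inputs_embeds[mask] = vision_embeddings.view(
        -1, inputs_embeds.shape[-1]).to(inputs_embeds.dtype)
    return inputs_embeds

hidden = 16
image_token_id = 32000
input_ids = torch.tensor([1, 32000, 32000, 32000, 7, 8])
inputs_embeds = torch.zeros(6, hidden)
vision_embeddings = torch.ones(1, 3, hidden)  # 1 image -> 3 placeholder slots

out = merge_vision_embeddings(input_ids, inputs_embeds, vision_embeddings,
                              image_token_id)
assert out[1:4].eq(1).all() and out[0].eq(0).all()
```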
+ stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + params_dict = dict(self.named_parameters()) + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items(): + if key_to_modify in name: + name = name.replace(key_to_modify, new_key) + use_default_weight_loading = False + if "vision" in name: + if self.vision_tower is not None: + # We only do sharding for language model and + # not vision model for now. + use_default_weight_loading = True + else: + for (param_name, weight_name, + shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + param = params_dict[name.replace(weight_name, param_name)] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + use_default_weight_loading = True + if use_default_weight_loading: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index 0f82549780ba4..3faf54d292b99 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -147,7 +147,7 @@ def __init__( "weight_loader": self.weight_loader, }) - # ACT_SCALE (for fp8) + # INPUT_SCALE (for fp8) if quant_config.activation_scheme == "static": if not quant_config.is_checkpoint_fp8_serialized: raise ValueError( @@ -182,11 +182,11 @@ def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, param_data[expert_id, :, :] = loaded_weight[:, shard] # Loading scales - if "act_scale" in weight_name or "w2.weight_scale" in weight_name: + if "input_scale" in weight_name or "w2.weight_scale" in weight_name: if param_data[expert_id] != 1 and (param_data[expert_id] - loaded_weight).abs() > 1e-5: raise ValueError( - "act_scales of w1 and w3 of a layer " + "input_scales of w1 and w3 of a layer " f"must be equal. But got {param_data[expert_id]} " f"vs. {loaded_weight}") param_data[expert_id] = loaded_weight @@ -225,9 +225,9 @@ def process_weights_after_loading(self): self.w2_weight = nn.Parameter(w2_weight, requires_grad=False) else: - # If checkpoint is fp8 + static, cleanup act_scales. - # Since state_dict has an act_scale per expert but our kernels - # are passed one act_scale shared across all experts. + # If checkpoint is fp8 + static, cleanup input_scales. + # Since state_dict has an input_scale per expert but our kernels + # are passed one input_scale shared across all experts. if self.quant_config.activation_scheme == "static": if self.a13_scale is None or self.a2_scale is None: raise ValueError( @@ -237,7 +237,7 @@ def process_weights_after_loading(self): if (not all_close_1d(self.a13_scale) or not all_close_1d(self.a2_scale)): print_warning_once( - "Found act_scales that are not equal for " + "Found input_scales that are not equal for " "fp8 MoE layer. Using the maximum across experts " "for each layer. 
") @@ -576,7 +576,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # These are the activation scales for the experts # (param_name, weight_name, expert_id) ("a13_scale" if weight_name in ["w1", "w3"] else "a2_scale", - f"experts.{expert_id}.{weight_name}.act_scale", expert_id) + f"experts.{expert_id}.{weight_name}.input_scale", expert_id) for expert_id in range(self.config.num_local_experts) for weight_name in ["w1", "w2", "w3"] ] diff --git a/vllm/multimodal/utils.py b/vllm/multimodal/utils.py new file mode 100644 index 0000000000000..c6311d60e0bdd --- /dev/null +++ b/vllm/multimodal/utils.py @@ -0,0 +1,85 @@ +import base64 +from io import BytesIO +from typing import Optional, Union + +import aiohttp +from PIL import Image + +from vllm.config import ModelConfig +from vllm.envs import VLLM_IMAGE_FETCH_TIMEOUT +from vllm.multimodal.image import ImagePixelData + + +class ImageFetchAiohttp: + aiohttp_client: Optional[aiohttp.ClientSession] = None + + @classmethod + def get_aiohttp_client(cls) -> aiohttp.ClientSession: + if cls.aiohttp_client is None: + timeout = aiohttp.ClientTimeout(total=VLLM_IMAGE_FETCH_TIMEOUT) + connector = aiohttp.TCPConnector() + cls.aiohttp_client = aiohttp.ClientSession(timeout=timeout, + connector=connector) + + return cls.aiohttp_client + + @classmethod + async def fetch_image(cls, image_url: str) -> Image.Image: + """Load PIL image from a url or base64 encoded openai GPT4V format""" + + if image_url.startswith('http'): + # Avoid circular import + from vllm import __version__ as VLLM_VERSION + + client = cls.get_aiohttp_client() + headers = {"User-Agent": f"vLLM/{VLLM_VERSION}"} + + async with client.get(url=image_url, headers=headers) as response: + response.raise_for_status() + image_raw = await response.read() + image = Image.open(BytesIO(image_raw)) + + # Only split once and assume the second part is the base64 encoded image + elif image_url.startswith('data:image'): + image = load_image_from_base64(image_url.split(',', 1)[1]) + + else: + raise ValueError("Invalid image url: A valid image url must start " + "with either 'data:image' or 'http'.") + + return image + + +async def async_get_and_parse_image(image_url: str) -> ImagePixelData: + with await ImageFetchAiohttp.fetch_image(image_url) as image: + return ImagePixelData(image) + + +def encode_image_base64(image: Image.Image, format: str = 'JPEG') -> str: + """encode image to base64 format.""" + + buffered = BytesIO() + if format == 'JPEG': + image = image.convert('RGB') + image.save(buffered, format) + return base64.b64encode(buffered.getvalue()).decode('utf-8') + + +def load_image_from_base64(image: Union[bytes, str]) -> Image.Image: + """Load image from base64 format.""" + return Image.open(BytesIO(base64.b64decode(image))) + + +# TODO(ywang96): move this to a model registry for preprocessing vision +# language prompts based on the model type. 
+def get_full_image_text_prompt(image_prompt: str, text_prompt: str, + config: ModelConfig) -> str: + """Combine image and text prompts for vision language model depending on + the model architecture.""" + + if config.hf_config.model_type in ("llava", "llava_next"): + full_prompt = f"{image_prompt}\n{text_prompt}" + else: + raise ValueError( + f"Unsupported model type: {config.hf_config.model_type}") + return full_prompt diff --git a/vllm/utils.py b/vllm/utils.py index 2bd24d086f690..54d446b23350a 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -22,6 +22,7 @@ import torch import vllm.envs as envs +from vllm import _custom_ops as ops from vllm.logger import enable_trace_function_call, init_logger T = TypeVar("T") @@ -148,12 +149,8 @@ def is_neuron() -> bool: @lru_cache(maxsize=None) def get_max_shared_memory_bytes(gpu: int = 0) -> int: """Returns the maximum shared memory per thread block in bytes.""" - # NOTE: This import statement should be executed lazily since - # the Neuron-X backend does not have the `cuda_utils` module. - from vllm._C import cuda_utils - max_shared_mem = ( - cuda_utils.get_max_shared_memory_per_block_device_attribute(gpu)) + ops.get_max_shared_memory_per_block_device_attribute(gpu)) # value 0 will cause MAX_SEQ_LEN become negative and test_attention.py # will fail assert max_shared_mem > 0, "max_shared_mem can not be zero" diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index c59288b4f73c6..7879a5de5b7bd 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1,3 +1,4 @@ +import gc import time import warnings from collections import defaultdict @@ -894,6 +895,10 @@ def capture_model(self, kv_caches: List[torch.Tensor]) -> None: seq_lens = torch.ones(max_batch_size, dtype=torch.int32).cuda() block_tables = torch.from_numpy(self.graph_block_tables).cuda() + # Prepare buffer for outputs. These will be reused for all batch sizes. + # It will be filled after the first graph capture. + hidden_states: Optional[torch.Tensor] = None + graph_batch_size = _get_graph_batch_size( self.scheduler_config.max_num_seqs) batch_size_capture_list = [ @@ -930,9 +935,11 @@ def capture_model(self, kv_caches: List[torch.Tensor]) -> None: self.set_active_loras(set(), lora_mapping) graph_runner = CUDAGraphRunner(self.model) - graph_runner.capture( + hidden_states = graph_runner.capture( input_tokens[:batch_size], input_positions[:batch_size], + hidden_states[:batch_size] + if hidden_states is not None else None, kv_caches, attn_metadata, memory_pool=self.graph_memory_pool, @@ -969,12 +976,13 @@ def capture( self, input_ids: torch.Tensor, positions: torch.Tensor, + hidden_states: Optional[torch.Tensor], kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, memory_pool: Optional[Tuple[int, int]], stream: torch.cuda.Stream, **kwargs, - ) -> None: + ) -> torch.Tensor: assert self._graph is None # Run the model a few times without capturing the graph. # This is to make sure that the captured graph does not include the @@ -993,13 +1001,21 @@ def capture( # Capture the graph. 
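The `capture_model`/`capture` changes above let every captured batch size share one hidden-states allocation: the first capture adopts the output tensor allocated inside the graph memory pool, and later captures `copy_` their output into a slice of it. A hedged sketch of that pattern with a toy model; `capture` here is a stand-in for, not a copy of, vLLM's `CUDAGraphRunner.capture`, and a CUDA device is required:

```python
import torch

def capture(model, static_input, out_buffer, pool):
    """Capture one graph; reuse `out_buffer` for outputs when given."""
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph, pool=pool):
        hidden = model(static_input)
        if out_buffer is not None:
            out_buffer.copy_(hidden)   # later captures: write into shared buffer
        else:
            out_buffer = hidden        # first capture: adopt the pool allocation
    return graph, out_buffer

if torch.cuda.is_available():
    model = torch.nn.Linear(8, 8).cuda()
    x = torch.randn(4, 8, device="cuda")
    for _ in range(3):                 # warm up before capturing
        model(x)
    torch.cuda.synchronize()

    pool = torch.cuda.graph_pool_handle()
    graph_4, hidden = capture(model, x, None, pool)
    # a smaller batch reuses a slice of the same output buffer
    graph_2, _ = capture(model, x[:2], hidden[:2], pool)
```

The patch's `del output_hidden_states` plus `gc.collect()` serves the same goal on the later captures: the temporary graph output is released back into the shared memory pool instead of accumulating per batch size.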
self._graph = torch.cuda.CUDAGraph() with torch.cuda.graph(self._graph, pool=memory_pool, stream=stream): - hidden_states = self.model( + output_hidden_states = self.model( input_ids, positions, kv_caches, attn_metadata, **kwargs, ) + if hidden_states is not None: + hidden_states.copy_(output_hidden_states) + else: + hidden_states = output_hidden_states + del output_hidden_states + # make sure `output_hidden_states` is deleted + # in the graph's memory pool + gc.collect() torch.cuda.synchronize() # Save the input and output buffers. @@ -1012,7 +1028,7 @@ def capture( "block_tables": attn_metadata.decode_metadata.block_tables, } self.output_buffers = {"hidden_states": hidden_states} - return + return hidden_states def forward( self,
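Replaying a captured graph then follows the usual static-buffer protocol that `CUDAGraphRunner.forward` (cut off above) implements: copy fresh inputs into the captured input buffers, call `replay()`, and read the result out of the shared output buffer. A minimal sketch under the same toy-model assumptions as before (requires a CUDA device):

```python
import torch

if torch.cuda.is_available():
    model = torch.nn.Linear(8, 8).cuda()
    static_x = torch.zeros(4, 8, device="cuda")
    for _ in range(3):                # warm up outside the graph
        model(static_x)
    torch.cuda.synchronize()

    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        static_out = model(static_x)  # captured output buffer

    static_x.copy_(torch.randn(4, 8, device="cuda"))  # fill input buffer
    graph.replay()                    # re-run the recorded kernels
    result = static_out.clone()      # read from the captured output buffer
```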