[Bug] vllm updated its get_model function #1183

Closed · 5 tasks done
zhaochenyang20 opened this issue Aug 22, 2024 · 4 comments

zhaochenyang20 (Collaborator) commented Aug 22, 2024

Checklist

  1. I have searched related issues but cannot get the expected help.
  2. The bug has not been fixed in the latest version.
  3. Please note that if the bug-related issue you submitted lacks corresponding environment info and a minimal reproducible demo, it will be challenging for us to reproduce and resolve the issue, reducing the likelihood of receiving feedback.
  4. If the issue you raised is not a bug but a question, please raise a discussion at https://github.com/sgl-project/sglang/discussions/new/choose. Otherwise, it will be closed.
  5. Please use English; otherwise the issue will be closed.

Describe the bug

Previously, we called get_model like this in our ModelRunner:

        self.model = get_model(
            model_config=self.vllm_model_config,
            device_config=self.device_config,
            load_config=self.load_config,
            lora_config=None,
            multimodal_config=None,
            parallel_config=None,
            scheduler_config=None,
            cache_config=None,
        )

Note that vllm updated its get_model function several days ago and removed the multimodal_config parameter.

https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/model_loader/__init__.py
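One way to stay compatible with both vllm versions (a minimal sketch; get_model_compat is my own hypothetical wrapper, not sglang code) is to filter our keyword arguments against whatever signature the installed vllm actually exposes:

    import inspect

    from vllm.model_executor.model_loader import get_model

    def get_model_compat(**kwargs):
        # Keep only the keyword arguments this vllm version's get_model
        # still accepts, so removed parameters such as multimodal_config
        # are dropped instead of raising a TypeError.
        accepted = inspect.signature(get_model).parameters
        return get_model(**{k: v for k, v in kwargs.items() if k in accepted})

This keeps ModelRunner importable across both versions, at the cost of silently ignoring arguments the newer loader no longer understands.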

Also, even if we delete this parameter from our ModelRunner:

        self.model = get_model(
            model_config=self.vllm_model_config,
            device_config=self.device_config,
            load_config=self.load_config,
            lora_config=None,
            # multimodal_config=None,
            parallel_config=None,
            scheduler_config=None,
            cache_config=None,
        )

The unit test in test/srt/models/test_embedding_models.py still does not pass, because the new get_model function loads the LlamaEmbeddingModel class defined in vllm/model_executor/models/llama_embedding.py.

However, we want get_model to load our own LlamaEmbeddingModel class from https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/models/llama_embedding.py

Could someone check what's happening in the new get_model function?
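
In the meantime, one possible workaround (a sketch, not a confirmed fix: it assumes vllm's ModelRegistry.register_model API, and the architecture key "LlamaEmbeddingModel" is my guess; the real key depends on the architectures field of the model config) is to re-register the architecture so the loader resolves to sglang's class:

    from vllm import ModelRegistry

    from sglang.srt.models.llama_embedding import LlamaEmbeddingModel

    # Assumed key: re-register the architecture so vllm's loader picks
    # sglang's LlamaEmbeddingModel instead of the one in
    # vllm/model_executor/models/llama_embedding.py.
    ModelRegistry.register_model("LlamaEmbeddingModel", LlamaEmbeddingModel)

Whether this override wins depends on registration order inside vllm, so it should be treated as a stopgap rather than a fix.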

Reproduction

Simply using the latest code of both vllm and sglang reproduces it.

Below is my traceback:

(AlphaMeemory) (AlphaMeemory) chenyang@uclaml04:/data/chenyang/sglang/test/srt/models$ CUDA_VISIBALE_DEVICES=5 python3 test_embedding_models.py 
Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  5.35it/s]
INFO 08-21 21:24:10 weight_utils.py:236] Using model weights format ['*.safetensors']
Loading safetensors checkpoint shards:   0% Completed | 0/2 [00:00<?, ?it/s]
Loading safetensors checkpoint shards:  50% Completed | 1/2 [00:01<00:01,  1.12s/it]
Loading safetensors checkpoint shards: 100% Completed | 2/2 [00:01<00:00,  1.12it/s]
Loading safetensors checkpoint shards: 100% Completed | 2/2 [00:01<00:00,  1.08it/s]

Exception in ModelTpServer:
Traceback (most recent call last):
  File "/home/chenyang/miniconda3/envs/AlphaMeemory/lib/python3.11/site-packages/sglang/srt/managers/tp_worker.py", line 233, in exposed_step
    self.forward_step()
  File "/home/chenyang/miniconda3/envs/AlphaMeemory/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/chenyang/miniconda3/envs/AlphaMeemory/lib/python3.11/site-packages/sglang/srt/managers/tp_worker.py", line 249, in forward_step
    self.forward_prefill_batch(new_batch)
  File "/home/chenyang/miniconda3/envs/AlphaMeemory/lib/python3.11/site-packages/sglang/srt/managers/tp_worker.py", line 540, in forward_prefill_batch
    output = self.model_runner.forward(batch, ForwardMode.EXTEND)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/chenyang/miniconda3/envs/AlphaMeemory/lib/python3.11/site-packages/sglang/srt/model_executor/model_runner.py", line 527, in forward
    return self.forward_extend(batch)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/chenyang/miniconda3/envs/AlphaMeemory/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/chenyang/miniconda3/envs/AlphaMeemory/lib/python3.11/site-packages/sglang/srt/model_executor/model_runner.py", line 501, in forward_extend
    return self.model.forward(
           ^^^^^^^^^^^^^^^^^^^
TypeError: LlamaEmbeddingModel.forward() missing 1 required positional argument: 'attn_metadata'

Exception in ControllerSingle:
Traceback (most recent call last):
  File "/home/chenyang/miniconda3/envs/AlphaMeemory/lib/python3.11/site-packages/sglang/srt/managers/controller_single.py", line 165, in start_controller_process
    controller.loop_for_forward()
  File "/home/chenyang/miniconda3/envs/AlphaMeemory/lib/python3.11/site-packages/sglang/srt/managers/controller_single.py", line 102, in loop_for_forward
    out_pyobjs = self.tp_server.exposed_step(recv_reqs)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/chenyang/miniconda3/envs/AlphaMeemory/lib/python3.11/site-packages/sglang/srt/managers/tp_worker.py", line 233, in exposed_step
    self.forward_step()
  File "/home/chenyang/miniconda3/envs/AlphaMeemory/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/chenyang/miniconda3/envs/AlphaMeemory/lib/python3.11/site-packages/sglang/srt/managers/tp_worker.py", line 249, in forward_step
    self.forward_prefill_batch(new_batch)
  File "/home/chenyang/miniconda3/envs/AlphaMeemory/lib/python3.11/site-packages/sglang/srt/managers/tp_worker.py", line 540, in forward_prefill_batch
    output = self.model_runner.forward(batch, ForwardMode.EXTEND)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/chenyang/miniconda3/envs/AlphaMeemory/lib/python3.11/site-packages/sglang/srt/model_executor/model_runner.py", line 527, in forward
    return self.forward_extend(batch)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/chenyang/miniconda3/envs/AlphaMeemory/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/chenyang/miniconda3/envs/AlphaMeemory/lib/python3.11/site-packages/sglang/srt/model_executor/model_runner.py", line 501, in forward_extend
    return self.model.forward(
           ^^^^^^^^^^^^^^^^^^^
TypeError: LlamaEmbeddingModel.forward() missing 1 required positional argument: 'attn_metadata'

E
======================================================================
ERROR: test_prefill_logits (__main__.TestEmbeddingModels.test_prefill_logits)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/data/chenyang/sglang/test/srt/models/test_embedding_models.py", line 67, in test_prefill_logits
    self.assert_close_prefill_logits(
  File "/data/chenyang/sglang/test/srt/models/test_embedding_models.py", line 41, in assert_close_prefill_logits
    with SRTRunner(
         ^^^^^^^^^^
  File "/home/chenyang/miniconda3/envs/AlphaMeemory/lib/python3.11/site-packages/sglang/test/runners.py", line 189, in __init__
    self.runtime = Runtime(
                   ^^^^^^^^
  File "/home/chenyang/miniconda3/envs/AlphaMeemory/lib/python3.11/site-packages/sglang/srt/server.py", line 533, in __init__
    raise RuntimeError(
RuntimeError: Initialization failed. Please see the error messages above.

----------------------------------------------------------------------
Ran 1 test in 20.086s

The error "TypeError: LlamaEmbeddingModel.forward() missing 1 required positional argument: 'attn_metadata'" told me that the get_model function is loading the LlamaEmbeddingModel in vllm but not sglang.

Environment

Python: 3.11.7 (main, Dec 15 2023, 18:12:31) [GCC 11.2.0]
CUDA available: True
GPU 0,1,2,3,4,5,6,7: NVIDIA RTX A6000
GPU 0,1,2,3,4,5,6,7 Compute Capability: 8.6
CUDA_HOME: /usr/local/cuda
NVCC: Cuda compilation tools, release 12.3, V12.3.103
CUDA Driver Version: 545.23.08
PyTorch: 2.4.0+cu121
sglang: 0.2.13
flashinfer: 0.1.5+cu121torch2.4
triton: 3.0.0
transformers: 4.43.3
requests: 2.32.3
tqdm: 4.66.4
numpy: 1.26.4
aiohttp: 3.9.5
fastapi: 0.112.1
hf_transfer: 0.1.8
huggingface_hub: 0.24.3
interegular: 0.3.3
packaging: 24.1
PIL: 10.4.0
psutil: 6.0.0
pydantic: 2.8.2
uvicorn: 0.23.2
uvloop: 0.19.0
zmq: 26.0.3
vllm: 0.5.4
multipart: 0.0.9
openai: 1.40.3
anthropic: 0.33.0
NVIDIA Topology:
        GPU0  GPU1  GPU2  GPU3  GPU4  GPU5  GPU6  GPU7  CPU Affinity  NUMA Affinity  GPU NUMA ID
GPU0    X     SYS   SYS   SYS   SYS   SYS   SYS   SYS   0-15,32-47    0              N/A
GPU1    SYS   X     SYS   SYS   SYS   SYS   SYS   SYS   0-15,32-47    0              N/A
GPU2    SYS   SYS   X     SYS   SYS   SYS   SYS   SYS   0-15,32-47    0              N/A
GPU3    SYS   SYS   SYS   X     SYS   SYS   SYS   SYS   0-15,32-47    0              N/A
GPU4    SYS   SYS   SYS   SYS   X     SYS   SYS   SYS   16-31,48-63   1              N/A
GPU5    SYS   SYS   SYS   SYS   SYS   X     SYS   SYS   16-31,48-63   1              N/A
GPU6    SYS   SYS   SYS   SYS   SYS   SYS   X     SYS   16-31,48-63   1              N/A
GPU7    SYS   SYS   SYS   SYS   SYS   SYS   SYS   X     16-31,48-63   1              N/A

Legend:

X = Self
SYS = Connection traversing PCIe as well as the SMP interconnect between NUMA nodes (e.g., QPI/UPI)
NODE = Connection traversing PCIe as well as the interconnect between PCIe Host Bridges within a NUMA node
PHB = Connection traversing PCIe as well as a PCIe Host Bridge (typically the CPU)
PXB = Connection traversing multiple PCIe bridges (without traversing the PCIe Host Bridge)
PIX = Connection traversing at most a single PCIe bridge
NV# = Connection traversing a bonded set of # NVLinks

ulimit soft: 1048576

exceedzhang (Contributor)

+1

HelloCard

+1

(base) root@DESKTOP-O6DNFE1:/mnt/c/Windows/system32# python3 -m sglang.launch_server --model-path /mnt/e/Code/models/orca-2-13B-AWQ --quantization awq --tensor-parallel-size 2 --port 8000
server_args=ServerArgs(model_path='/mnt/e/Code/models/orca-2-13B-AWQ', tokenizer_path='/mnt/e/Code/models/orca-2-13B-AWQ', tokenizer_mode='auto', skip_tokenizer_init=False, load_format='auto', dtype='auto', trust_remote_code=False, context_length=None, quantization='awq', served_model_name='/mnt/e/Code/models/orca-2-13B-AWQ', chat_template=None, host='127.0.0.1', port=8000, additional_ports=[8001, 8002, 8003, 8004], mem_fraction_static=0.87, max_running_requests=None, max_num_reqs=None, max_total_tokens=None, chunked_prefill_size=8192, max_prefill_tokens=16384, schedule_policy='lpm', schedule_conservativeness=1.0, tp_size=2, stream_interval=1, random_seed=962708718, log_level='info', log_level_http=None, log_requests=False, show_time_cost=False, api_key=None, file_storage_pth='SGLang_storage', dp_size=1, load_balance_method='round_robin', disable_flashinfer=False, disable_flashinfer_sampling=False, disable_radix_cache=False, disable_regex_jump_forward=False, disable_cuda_graph=False, disable_disk_cache=False, enable_torch_compile=False, enable_p2p_check=False, enable_mla=False, attention_reduce_in_fp32=False, efficient_weight_load=False, nccl_init_addr=None, nnodes=1, node_rank=None)
[gpu=0] Init nccl begin.
[gpu=1] Init nccl begin.
[gpu=0] Load weight begin. avail mem=20.53 GB
[gpu=1] Load weight begin. avail mem=20.53 GB
Exception in run_tp_server:
Traceback (most recent call last):
  File "/root/miniconda3/lib/python3.11/site-packages/sglang/srt/managers/tp_worker.py", line 781, in run_tp_server
    model_server = ModelTpServer(
                   ^^^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.11/site-packages/sglang/srt/managers/tp_worker.py", line 99, in __init__
    self.model_runner = ModelRunner(
                        ^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.11/site-packages/sglang/srt/model_executor/model_runner.py", line 128, in __init__
    self.load_model()
  File "/root/miniconda3/lib/python3.11/site-packages/sglang/srt/model_executor/model_runner.py", line 172, in load_model
    self.model = get_model(
                 ^^^^^^^^^^
TypeError: get_model() got an unexpected keyword argument 'multimodal_config'

Process Process-1:1:
Initialization failed. controller_init_state: Traceback (most recent call last):
  File "/root/miniconda3/lib/python3.11/site-packages/sglang/srt/managers/controller_single.py", line 150, in start_controller_process
    controller = ControllerSingle(
                 ^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.11/site-packages/sglang/srt/managers/controller_single.py", line 84, in __init__
    self.tp_server = ModelTpServer(
                     ^^^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.11/site-packages/sglang/srt/managers/tp_worker.py", line 99, in __init__
    self.model_runner = ModelRunner(
                        ^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.11/site-packages/sglang/srt/model_executor/model_runner.py", line 128, in __init__
    self.load_model()
  File "/root/miniconda3/lib/python3.11/site-packages/sglang/srt/model_executor/model_runner.py", line 172, in load_model
    self.model = get_model(
                 ^^^^^^^^^^
TypeError: get_model() got an unexpected keyword argument 'multimodal_config'

Initialization failed. detoken_init_state: init ok
Traceback (most recent call last):
  File "/root/miniconda3/lib/python3.11/multiprocessing/process.py", line 314, in _bootstrap
    self.run()
  File "/root/miniconda3/lib/python3.11/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/root/miniconda3/lib/python3.11/site-packages/sglang/srt/managers/tp_worker.py", line 781, in run_tp_server
    model_server = ModelTpServer(
                   ^^^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.11/site-packages/sglang/srt/managers/tp_worker.py", line 99, in __init__
    self.model_runner = ModelRunner(
                        ^^^^^^^^^^^^
  File "/root/miniconda3/lib/python3.11/site-packages/sglang/srt/model_executor/model_runner.py", line 128, in __init__
    self.load_model()
  File "/root/miniconda3/lib/python3.11/site-packages/sglang/srt/model_executor/model_runner.py", line 172, in load_model
    self.model = get_model(
                 ^^^^^^^^^^
TypeError: get_model() got an unexpected keyword argument 'multimodal_config'
/root/miniconda3/lib/python3.11/multiprocessing/resource_tracker.py:254: UserWarning: resource_tracker: There appear to be 1 leaked shared_memory objects to clean up at shutdown
  warnings.warn('resource_tracker: There appear to be %d '

zhaochenyang20 (Collaborator, Author) commented Aug 25, 2024

@exceedzhang @HelloCard

Hello, we've noticed this issue. For now, please pin vllm==0.5.4; we will fix the bug ASAP.

pip install vllm==0.5.4
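
Until the fix lands, a defensive check at startup (my own sketch, not existing sglang code; the version bound is taken from this thread) would fail fast with a clearer message than the TypeError above:

    import vllm
    from packaging.version import Version

    # Fail fast on an incompatible vllm instead of crashing later inside
    # ModelRunner.load_model with an opaque TypeError.
    if Version(vllm.__version__) > Version("0.5.4"):
        raise RuntimeError(
            f"sglang 0.2.13 expects vllm==0.5.4, found {vllm.__version__}; "
            "run `pip install vllm==0.5.4`."
        )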

zhyncs (Member) commented Aug 26, 2024

fixed with #1155

zhyncs closed this as completed Aug 26, 2024