[Bug]: Incomplete tool calling response for pipeline-parallel vllm with ray #7194

Open
sfbemerk opened this issue Aug 6, 2024 · 20 comments
Labels: bug (Something isn't working), unstale

Comments

@sfbemerk

sfbemerk commented Aug 6, 2024

Your current environment

vllm v0.5.4

Setup A) single docker container with vllm, no pipeline-parallelism

docker run ... vllm/vllm-openai:v0.5.4 --model "meta-llama/Meta-Llama-3.1-70B-Instruct" --tensor-parallel-size 2 --max-model-len=4096

Setup B) two docker containers with ray + vllm (pipeline parallelism)

docker run -it ... --network host --entrypoint /bin/bash vllm/vllm-openai:v0.5.4

# start ray head node in one of the docker containers
ray start --head --disable-usage-stats

# start ray worker node in the other docker container
ray start --address='<IP-ADDRESS>:6379'

# start vllm in head node container
vllm serve "meta-llama/Meta-Llama-3.1-70B-Instruct" --tensor-parallel-size 1 --pipeline-parallel-size 2 --distributed-executor-backend ray --max-model-len=4096

The issue does not depend on the model; for example, it also appears with meta-llama/Meta-Llama-3-70B-Instruct instead of Llama-3.1.

🐛 Describe the bug

Without pipeline parallelism, a tool-calling request is answered correctly with a valid tool call.
With pipeline parallelism, the same request returns an incomplete tool call (only a few tokens long, but still with status 200).

Example request:

{
  "model": "meta-llama/Meta-Llama-3.1-70B-Instruct",
  "max_tokens": 200,
  "seed": 8,
  "messages": [
    {
      "role": "user",
      "content": "What is the weather like in Berlin?"
    }
  ],
  "tools": [
    {
      "type": "function",
      "function": {
        "name": "get_current_weather",
        "description": "Get the current weather",
        "parameters": {
          "type": "object",
          "properties": {
            "location": {
              "type": "string",
              "description": "The city and state, e.g. San Francisco, CA"
            },
            "format": {
              "type": "string",
              "enum": ["celsius", "fahrenheit"],
              "description": "The temperature unit to use. Infer this from the users location."
            }
          },
          "required": ["location", "format"]
        }
      }
    }
  ],
  "tool_choice": {"type": "function", "function": {"name": "get_current_weather"}}
}

Response for setup A (no pipeline-parallelism):

{
  "id": "chat-a9e517fd2f0e4ac3817896596c5cc907",
  "object": "chat.completion",
  "created": 1722929009,
  "model": "meta-llama/Meta-Llama-3.1-70B-Instruct",
  "choices": [
    {
      "index": 0,
      "message": {
        "role": "assistant",
        "content": "",
        "tool_calls": [
          {
            "id": "chatcmpl-tool-0b5f4340c31d42c19765aa5f755586a3",
            "type": "function",
            "function": {
              "name": "get_current_weather",
              "arguments": "{\"location\": \"Berlin\", \"format\": \"celsius\"}"
            }
          }
        ]
      },
      "logprobs": null,
      "finish_reason": "stop",
      "stop_reason": null
    }
  ],
  "usage": {
    "prompt_tokens": 251,
    "total_tokens": 265,
    "completion_tokens": 14
  }
}

Response for setup B (pipeline-parallelism with ray):

{
  "id": "chat-80f69a8c5403478ea0c75943e7bab3af",
  "object": "chat.completion",
  "created": 1722931399,
  "model": "meta-llama/Meta-Llama-3.1-70B-Instruct",
  "choices": [
    {
      "index": 0,
      "message": {
        "role": "assistant",
        "content": "",
        "tool_calls": [
          {
            "id": "chatcmpl-tool-b6517d27352947969bf3d3dc8b8c6b98",
            "type": "function",
            "function": {
              "name": "get_current_weather",
              "arguments": "{\"location"
            }
          }
        ]
      },
      "logprobs": null,
      "finish_reason": "stop",
      "stop_reason": null
    }
  ],
  "usage": {
    "prompt_tokens": 251,
    "total_tokens": 254,
    "completion_tokens": 3
  }
}
@sfbemerk sfbemerk added the bug Something isn't working label Aug 6, 2024
@loredunk

loredunk commented Aug 6, 2024

I also encounter this error with Yi-34B.

@youkaichao
Member

cc @andoorve

@youkaichao
Member

How about using a single node and PP size 2 (so that it uses the multiprocessing backend)? Does it still have this issue?

@sfbemerk
Author

sfbemerk commented Aug 6, 2024

I just tested it, but the bug appears there as well:
Setup C) one docker container with 2 GPUs (pipeline-parallel, no ray)

docker run ... vllm/vllm-openai:v0.5.4 --model "meta-llama/Meta-Llama-3.1-70B-Instruct" --tensor-parallel-size 1 --pipeline-parallel-size 2 --max-model-len=4096

Again, the incomplete response:

{
  "id": "chat-28622b0c3d074caead37fb5441ca2407",
  "object": "chat.completion",
  "created": 1722964614,
  "model": "meta-llama/Meta-Llama-3.1-70B-Instruct",
  "choices": [
    {
      "index": 0,
      "message": {
        "role": "assistant",
        "content": "",
        "tool_calls": [
          {
            "id": "chatcmpl-tool-9dfa147ec16245d7a8efd67ddfc6012f",
            "type": "function",
            "function": {
              "name": "get_current_weather",
              "arguments": "{\"location"
            }
          }
        ]
      },
      "logprobs": null,
      "finish_reason": "stop",
      "stop_reason": null
    }
  ],
  "usage": {
    "prompt_tokens": 251,
    "total_tokens": 254,
    "completion_tokens": 3
  }
}

@loredunk

loredunk commented Aug 7, 2024

I solved this: you only need to add "min_tokens" to the request. I set it to 50 based on my completion length.
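
For illustration, a hedged sketch of this workaround against vLLM's OpenAI-compatible API; the endpoint URL is an assumption, and the payload mirrors the request from the issue description with "min_tokens" (one of vLLM's extra sampling parameters) added:

```python
# Sketch only: same request as in the issue description, plus "min_tokens" so
# the server cannot stop after only a few tokens. URL and value are illustrative.
import requests

payload = {
    "model": "meta-llama/Meta-Llama-3.1-70B-Instruct",
    "max_tokens": 200,
    "min_tokens": 50,  # the workaround: force at least 50 completion tokens
    "seed": 8,
    "messages": [{"role": "user", "content": "What is the weather like in Berlin?"}],
    "tools": [{
        "type": "function",
        "function": {
            "name": "get_current_weather",
            "description": "Get the current weather",
            "parameters": {
                "type": "object",
                "properties": {
                    "location": {"type": "string"},
                    "format": {"type": "string", "enum": ["celsius", "fahrenheit"]},
                },
                "required": ["location", "format"],
            },
        },
    }],
    "tool_choice": {"type": "function", "function": {"name": "get_current_weather"}},
}

resp = requests.post("http://localhost:8000/v1/chat/completions", json=payload, timeout=120)
print(resp.json()["choices"][0]["message"]["tool_calls"])
```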

@andoorve
Collaborator

andoorve commented Aug 7, 2024

I did some investigation; the problem is quite deep. In the case of guided decoding we use stateful logits processors. However, with PP the logits processors get cloned when they are sent to workers other than the driver worker. This is OK for the TP use cases, since there the logits processor is not cloned and its state lives as long as the sequence group does. Beyond PP, however, this is an issue and will affect anything SPMD. For those cases, one solution might be to have the logits processor live on the worker. Pinging @njhill, as this is quite similar to the seed issue and he might have some suggestions on the best way to resolve it.

MP GPU Executor, [<vllm.model_executor.guided_decoding.outlines_logits_processors.JSONLogitsProcessor object at 0x7f57d577eaf0>]
Rank: 0, [<vllm.model_executor.guided_decoding.outlines_logits_processors.JSONLogitsProcessor object at 0x7f57d577eaf0>]
(VllmWorkerProcess pid=2199329) Rank: 1, [<vllm.model_executor.guided_decoding.outlines_logits_processors.JSONLogitsProcessor object at 0x7f5811ac5ac0>]
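
For illustration, a toy sketch (not vLLM's real classes) of the failure mode described above: once a stateful processor is cloned to another rank, the copies no longer share FSM state:

```python
# Toy illustration of the problem: a stateful logits processor is deep-copied
# when it is shipped to a non-driver PP rank, so the copies advance their FSM
# state independently and the guided-decoding constraint goes wrong.
import copy


class StatefulJSONProcessor:
    """Stand-in for outlines' JSONLogitsProcessor: it carries a mutable fsm_state."""

    def __init__(self) -> None:
        self.fsm_state = 0

    def __call__(self, token_id: int) -> None:
        # A real processor would mask logits based on fsm_state; we only advance it.
        self.fsm_state += 1


driver_proc = StatefulJSONProcessor()
worker_proc = copy.deepcopy(driver_proc)  # what serializing to another rank amounts to

for token_id in range(3):
    driver_proc(token_id)  # only one of the copies observes the decoded tokens

print(driver_proc.fsm_state, worker_proc.fsm_state)  # 3 0 -- the states have diverged
```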

@youkaichao
Member

I think we need to overhaul the logits processor part; it should not be part of the sampling parameters.

The sampling parameters should just store the constraint, like a regex or JSON schema, and the constrained-decoding state (i.e. the FSM state) should live in each sequence. The FSM itself, with masks for every state, should be retrieved from a cache that maps from regex to FSM.
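
For illustration, a minimal sketch of this proposed split, using placeholder names rather than vLLM's actual classes:

```python
# Sketch: sampling params carry only the declarative constraint, the compiled
# FSM is fetched from a cache keyed on that constraint, and the only mutable
# piece -- the FSM state -- lives with the sequence itself.
from dataclasses import dataclass
from functools import lru_cache


@dataclass(frozen=True)
class Constraint:
    """Declarative spec (e.g. a regex or JSON schema string); hashable and cheap
    to serialize, so it is safe to carry inside sampling params."""
    regex: str


class CompiledFSM:
    """Stand-in for a compiled guide with per-state token masks.
    Read-only after construction, so sharing it across sequences and workers is safe."""

    def __init__(self, constraint: Constraint) -> None:
        # Toy "compilation": one state per literal character of the pattern.
        # A real guide (e.g. outlines) builds a token-level DFA with masks.
        self.allowed = {i: {ch} for i, ch in enumerate(constraint.regex)}

    def allowed_tokens(self, state: int) -> set:
        return self.allowed.get(state, set())

    def next_state(self, state: int, token: str) -> int:
        return state + 1 if token in self.allowed_tokens(state) else state


@lru_cache(maxsize=128)
def get_fsm(constraint: Constraint) -> CompiledFSM:
    # The constraint -> FSM cache; compilation happens once per distinct constraint.
    return CompiledFSM(constraint)


@dataclass
class SequenceGuidedState:
    """Mutable per-sequence decoding state; it lives wherever the sequence lives."""
    constraint: Constraint
    fsm_state: int = 0

    def advance(self, token: str) -> None:
        self.fsm_state = get_fsm(self.constraint).next_state(self.fsm_state, token)
```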

@andoorve
Collaborator

andoorve commented Aug 8, 2024

Yes, exactly.

> The sampling parameters should just store the constraint, like a regex or JSON schema.

Yup, agreed.

> ... and the constrained-decoding state (i.e. the FSM state) should live in each sequence. The FSM itself, with masks for every state, should be retrieved from a cache that maps from regex to FSM.

We can keep the FSM as part of the logits processor but map from sequence to FSM in the worker, similar to the torch Generator. Ideally this is done in a clean and future-proof way.

@rkooo567 You mention support for guided decoding in #7109. Is this something you're looking into?

@rkooo567
Collaborator

rkooo567 commented Aug 8, 2024

> @rkooo567 You mention support for guided decoding in #7109. Is this something you're looking into?

Not exactly, but to make guided decoding work with SPMD, I believe we need exactly the same solution, so I may look into it soon (#7109).

@rkooo567
Collaborator

rkooo567 commented Aug 8, 2024

Also, totally agreed with @youkaichao. I think that's the direction we should go in (and what we wanted to achieve with #5423).

@andoorve
Collaborator

andoorve commented Aug 8, 2024

Yes @rkooo567, I think a lot of the SPMD work is basically directly applicable to PP.

@jon-chuang
Contributor

jon-chuang commented Aug 8, 2024

There is already a PR to make the stateful logits processor shareable:

#5329

The idea is as described above: pass around logits-processor factories instead of instances, and then each Sequence (which lives on its worker) has its own instance of the logits processor.
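
For illustration, a hedged sketch of that factory approach, with placeholder names rather than the PR's actual code:

```python
# Sketch: sampling params carry a cheap, picklable *factory* instead of a live
# processor, and each Sequence lazily builds its own instance on whichever
# worker owns it, so no mutable FSM state ever crosses a process boundary.
import functools
from dataclasses import dataclass
from typing import Callable, Optional


class ToyGuidedProcessor:
    def __init__(self, schema: str) -> None:
        self.schema = schema
        self.fsm_state = 0  # mutable state stays local to a single sequence

    def __call__(self, token_id: int) -> None:
        self.fsm_state += 1  # a real processor would mask logits here


def make_processor_factory(schema: str) -> Callable[[], ToyGuidedProcessor]:
    # functools.partial over a module-level class pickles cleanly, unlike an
    # already-instantiated processor carrying FSM state.
    return functools.partial(ToyGuidedProcessor, schema)


@dataclass
class ToySequence:
    seq_id: int
    processor_factory: Callable[[], ToyGuidedProcessor]
    _processor: Optional[ToyGuidedProcessor] = None

    @property
    def processor(self) -> ToyGuidedProcessor:
        # Instantiated on first use, on the worker that owns this sequence.
        if self._processor is None:
            self._processor = self.processor_factory()
        return self._processor
```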

@jon-chuang
Contributor

jon-chuang commented Aug 8, 2024

@maxdebayser are you going to continue your work in #5329?

@jon-chuang
Contributor

jon-chuang commented Aug 8, 2024

To make #5329 work with #7109, I think the SequenceDataDelta can leave SequenceData.logit_processors untouched, and the logits processor can obtain its state from SequenceData as usual once it has been updated with the SequenceDataDelta.

In other words, probably minimal changes are required.

@njhill
Member

njhill commented Aug 8, 2024

> I think we need to overhaul the logits processor part; it should not be part of the sampling parameters.
>
> The sampling parameters should just store the constraint, like a regex or JSON schema, and the constrained-decoding state (i.e. the FSM state) should live in each sequence. The FSM itself, with masks for every state, should be retrieved from a cache that maps from regex to FSM.

Absolutely, this has been planned for some time: a larger overhaul of how logits processors work is planned (e.g. #5423), but it has stalled a bit given how it needs to fit with other changes. I feel we can make some incremental improvements in the meantime, potentially starting with #5329, which @jon-chuang referenced.

Regarding the need for the state to live in the worker for PP etc., I know it's also required for SPMD, but IMO this should just be part of making the batches stateful (same for the torch.Generators).

@maxdebayser
Contributor

@jon-chuang, I've been focusing on other issues for the past couple of weeks, but if this can be an incremental step toward the broader refactoring @njhill mentioned, I'll sync this PR with main again.


This issue has been automatically marked as stale because it has not had any activity within 90 days. It will be automatically closed if no further activity occurs within 30 days. Leave a comment if you feel this issue should remain open. Thank you!

@github-actions github-actions bot added the stale label Nov 12, 2024
@edallmeier

edallmeier commented Dec 9, 2024

Version: vllm==0.6.4.post1

We've encountered an issue, which I strongly assume is related to this.

When hosting Qwen/Qwen2-VL-72B-Instruct-AWQ in distributed mode via ray (with "--tensor-parallel-size 4 --pipeline-parallel-size 2"), it is not possible to use guided_json mode. It does not return an error code, but it stops generating after 3 tokens: `{"content"`. Increasing min_tokens via extra_body, as mentioned above, does not resolve the issue. When guided_json is disabled, the model generates a response without any issue. When "--pipeline-parallel-size 2" is not used, guided_json works properly.

If needed, I can provide more details/output, but currently this looks quite related.
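
For reference, a hedged sketch of how such a guided_json request can be issued against vLLM's OpenAI-compatible server; base URL, schema, and prompt here are illustrative:

```python
# Sketch only: vLLM's OpenAI-compatible server accepts "guided_json" and
# "min_tokens" as extra body parameters of the chat completions endpoint.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

schema = {
    "type": "object",
    "properties": {"content": {"type": "string"}},
    "required": ["content"],
}

resp = client.chat.completions.create(
    model="Qwen/Qwen2-VL-72B-Instruct-AWQ",
    messages=[{"role": "user", "content": "Describe the weather in Berlin."}],
    max_tokens=200,
    extra_body={"guided_json": schema, "min_tokens": 50},  # min_tokens did not help here
)
print(resp.choices[0].message.content)  # with pipeline parallelism this stops after ~3 tokens
```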

@andoorve
Collaborator

andoorve commented Dec 9, 2024

@njhill As far as I can tell we still expect this to be broken, right? I think some of the PRs planned around this are still open.

@github-actions github-actions bot added unstale and removed stale labels Dec 10, 2024
@richardliaw richardliaw added the ray anything related with ray label Dec 11, 2024
@ruisearch42
Collaborator

> I just tested it, but the bug appears there as well:
> Setup C) one docker container with 2 GPUs (pipeline-parallel, no ray)

I'm removing the ray label since the bug appears with the MP executor as well. cc @richardliaw

@ruisearch42 ruisearch42 removed the ray anything related with ray label Dec 20, 2024