
YaRN support implementation #1264

Merged (16 commits) Nov 3, 2023

Conversation

@Yard1 (Collaborator) commented Oct 5, 2023

Supersedes #1161, thank you @viktor-ferenczi for laying down the groundwork :)

This PR implements support for YaRN models.

YaRN paper: https://arxiv.org/abs/2309.00071
YaRN repository: https://github.com/jquesnelle/yarn
Smallest model to test with: https://huggingface.co/NousResearch/Yarn-Llama-2-7b-64k

Closes #980

viktor-ferenczi and others added 6 commits September 30, 2023 09:54
@Yard1 (Collaborator, Author) commented Oct 5, 2023

cc @WoosukKwon @zhuohan123

@Yard1 Yard1 mentioned this pull request Oct 5, 2023
@viktor-ferenczi (Contributor) commented

Use format.sh to format the code (it is in the top-level folder of the repository).

@viktor-ferenczi (Contributor) commented

Tested it, but got the following error with model NousResearch/Yarn-Llama-2-7b-64k:

ValueError: `rope_scaling` must be a dictionary with with two fields, `type` and `factor`, got {'factor': 16.0, 'original_max_position_embeddings': 4096, 'type': 'yarn', 'finetuned': True}

@Yard1 (Collaborator, Author) commented Oct 6, 2023

@viktor-ferenczi can you post the full stack trace?

@Yard1 (Collaborator, Author) commented Oct 7, 2023

@viktor-ferenczi were you loading from a directory or from an HF model id? I think it may not work in the former case, since HF Transformers will use the built-in LlamaConfig class, which causes validation to fail. It should work if you specify the HF model id instead, as that should use the configuration present in the repository. This is an HF Transformers issue.

@viktor-ferenczi (Contributor) commented Oct 7, 2023

Thanks for the diagnosis.

We have to pass trust_remote_code=True to LLM() to bypass these checks, which would otherwise fail:

In transformers/models/llama/configuration_llama.py:

if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
...
if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:

Now it can load the model and produce reasonable output. I'll go ahead and test YaRN itself.
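
For reference, a minimal sketch of loading the model this way (the prompt and sampling settings are illustrative, not part of the PR):

from vllm import LLM, SamplingParams

# Sketch: trust_remote_code=True makes HF Transformers use the configuration
# class shipped in the model repository instead of the built-in LlamaConfig,
# whose rope_scaling validation rejects type "yarn".
llm = LLM(
    model="NousResearch/Yarn-Llama-2-7b-64k",
    trust_remote_code=True,
)

outputs = llm.generate(
    ["The pass key is 71432. Remember it. The pass key is"],
    SamplingParams(temperature=0.0, max_tokens=16),
)
print(outputs[0].outputs[0].text)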

@viktor-ferenczi (Contributor) commented

Tested OK with the NousResearch/Yarn-Llama-2-7b-64k model up to 8192 context length using the pass_key_evaluator.py test script.

@viktor-ferenczi (Contributor) commented Oct 7, 2023

Attempted to test at full 64k context length:

            max_model_len=65536,
            max_num_batched_tokens=65536,

Ran into this error:

  File "/home/viktor/dep/vllm-contrib/vllm/worker/worker.py", line 381, in _check_if_can_support_max_seq_len
    raise RuntimeError(
RuntimeError: vLLM cannot currently support max_model_len=65536 with block_size=16 on GPU with compute capability (8, 9) (required shared memory 264252.0 > available shared memory 101376). This will be fixed in a future release.

The check is:

if padded_max_seq_len * float32_bytes > max_shared_mem:
    raise RuntimeError(...)

Variables:

max_shared_mem = 101376
float32_bytes = 4
padded_max_seq_len = 65551 = 65536 + 15

From this, padded_max_seq_len must be at most 25344 (101376 / 4), therefore max_model_len <= 25329.

Retrying the test with max_model_len = 25328, because that's divisible by 16.
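
For reference, the same arithmetic as a small sketch (variable names mirror the check above; the values are the ones from the error message):

max_shared_mem = 101376   # bytes available on compute capability (8, 9)
float32_bytes = 4
block_size = 16

# Largest padded sequence length that still passes the check: 101376 / 4 = 25344.
max_padded_seq_len = max_shared_mem // float32_bytes

# As described above, the check pads max_model_len up by (block_size - 1),
# so the largest usable max_model_len is 25344 - 15 = 25329.
print(max_padded_seq_len, max_padded_seq_len - (block_size - 1))  # 25344 25329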

Result: the first generation freezes with 100% load on one CPU core and no GPU utilization at all. Judging by VRAM usage the model is loaded onto the GPU; it is just not used.

It is frozen in entrypoints/llm.py inside this loop:

        while self.llm_engine.has_unfinished_requests():
            step_outputs = self.llm_engine.step()
            for output in step_outputs:
                if output.finished:
                    outputs.append(output)
                    if use_tqdm:
                        pbar.update(1)

step_outputs apparently never contains finished outputs, because outputs remains empty even after running for minutes.

@viktor-ferenczi (Contributor) commented Oct 7, 2023

More test results:

With one 4090 GPU (24GB VRAM):

  • 8448: OK
  • 12288: OK
  • 13312: OK
  • 13824: OK

With two 4090 GPUs (2x24GB VRAM), tensor_parallel_size=2, swap_space=8:

Crash at 25328:

  File "/home/viktor/dep/vllm-contrib/vllm/model_executor/models/llama.py", line 215, in forward
    hidden_states = residual + hidden_states
RuntimeError: CUDA error: invalid argument
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

In PassKeyEvaluator.evaluate() we need to add 8 tokens of headroom to make sure the generated key also fits within the model's full context size; otherwise it can be truncated.

@Yard1 (Collaborator, Author) commented Oct 7, 2023

I think it might be frozen because you do not have enough KV cache memory to support the entire context length. This is orthogonal to the error raised about shared memory. The only "fix" here would be to use a GPU with more memory (and perhaps to raise a warning in the scheduler if a new request cannot be scheduled due to lack of KV cache; that would be best done in a separate PR). I think quantization would also help.
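
As a back-of-the-envelope check of that explanation (Llama-2-7B shapes, fp16 KV cache; illustrative only, not measured):

# KV cache needed for a single 64K-token sequence on a 7B Llama model.
num_layers = 32
kv_dim = 4096            # num_kv_heads * head_dim for Llama-2-7B
bytes_per_elem = 2       # fp16
tokens = 65536

kv_bytes = 2 * num_layers * kv_dim * bytes_per_elem * tokens  # keys and values
print(kv_bytes / 2**30)  # ~32 GiB, more than a single 24 GB 4090 provides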

@viktor-ferenczi (Contributor) commented Oct 7, 2023

I'm going to try on both 4090s then, with tensor_parallel_size=2.

Another option is to increase swap_space from the default 4 GB, which adds more CPU memory to the KV cache.
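
A minimal sketch of that configuration (the values are the ones discussed in this thread, not general recommendations):

from vllm import LLM

# Sketch: shard the model across both 4090s and enlarge the CPU swap space
# used for swapped-out KV cache blocks; max_model_len stays under the
# shared-memory limit discussed above.
llm = LLM(
    model="NousResearch/Yarn-Llama-2-7b-64k",
    trust_remote_code=True,
    tensor_parallel_size=2,   # two GPUs
    swap_space=8,             # GiB of CPU memory per GPU
    max_model_len=25328,
)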

@viktor-ferenczi (Contributor) commented Oct 7, 2023

Tested OK on two 4090 GPUs up to 25088 context size. (Not exhaustively, only at various points, see above.)

I cannot go much higher due to various vLLM limitations (see above).

@viktor-ferenczi (Contributor) commented Oct 7, 2023

@casper-hansen We have made progress with YaRN here, but we are facing problems/limitations in vLLM at longer context sizes. Could you please advise? Thanks! - See #1264 (comment)

@Yard1 (Collaborator, Author) commented Oct 7, 2023

I think memory related context length limitations should not block this PR and can be fixed independently. I know that the vLLM team is investigating ways to improve shared memory usage, for example.

I'll apply the suggestions you have made, @viktor-ferenczi.

@viktor-ferenczi (Contributor) commented

I agree about not blocking this PR because of unrelated context length limitations.

I'm not sure how #555 was tested; I cannot find its test case in the source code. I wanted to reuse that test case for this PR as well, since they are related. (I could test manually with the pass-key retrieval test I added, but it is not streamlined or added to automation.)

@casper-hansen (Contributor) commented

> @casper-hansen We have made progress with YaRN here, but we are facing problems/limitations in vLLM at longer context sizes. Could you please advise? Thanks! - See #1264 (comment)

I would advise testing on A100s to make sure it's not a limit on the memory side. Other than that, it's probably some niche area of vLLM that I have little insight into.

@zhuohan123 added the "new model" label Oct 9, 2023
@WoosukKwon (Collaborator) commented

@Yard1 Would it be possible to first add YaRN without the tests? We also need to add tests for the other RoPE scaling methods, and it seems the YaRN test code will conflict with that.

@Yard1 Yard1 requested a review from WoosukKwon October 13, 2023 20:52
@WoosukKwon (Collaborator) commented

@Yard1 I got this error when I tried the NousResearch/Yarn-Llama-2-7b-64k model. Is this because I'm using transformers 4.34.0?

  File "/home/wskwon/anaconda3/envs/vllm/lib/python3.9/site-packages/transformers/models/llama/configuration_llama.py", line 171, in _rope_scaling_validation
    raise ValueError(
ValueError: `rope_scaling` must be a dictionary with with two fields, `type` and `factor`, got {'factor': 16.0, 'original_max_position_embeddings': 4096, 'type': 'yarn', 'finetuned': True}

@Yard1 (Collaborator, Author) commented Oct 15, 2023

@WoosukKwon did you use trust_remote_code=True? It is required here.

Actually, we should probably make that clearer... maybe catch this exception and raise our own error telling the user to set it to True.
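
A rough sketch of what that could look like (the helper name and error wording are hypothetical, not part of this PR):

from transformers import AutoConfig


def load_hf_config(model: str, trust_remote_code: bool = False):
    """Hypothetical helper: surface a clearer error when the built-in config
    class rejects a YaRN-style rope_scaling dict."""
    try:
        return AutoConfig.from_pretrained(
            model, trust_remote_code=trust_remote_code)
    except ValueError as e:
        if "rope_scaling" in str(e) and not trust_remote_code:
            raise ValueError(
                f"Failed to load the config for {model}. If it uses YaRN "
                "rope scaling, pass trust_remote_code=True so that the "
                "configuration class from the model repository is used."
            ) from e
        raise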

@WoosukKwon (Collaborator) commented

@Yard1 Got it. It works after adding trust_remote_code.

However, I got the following error during the initial memory profiling:

  File "/home/wskwon/workspace/vllm/vllm/model_executor/layers/attention.py", line 203, in forward
    output = torch.empty_like(query)
RuntimeError: CUDA error: an illegal memory access was encountered
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions

However, the error disappeared when I set max_model_len to a smaller number (e.g., 4096) than the default one (~1M). Have you seen this too?

We saw the same kind of error when using the Mistral model with AWQ (#1236). The error happened because of an integer overflow in pointer arithmetic inside the AWQ GEMM kernel when the number of input tokens is too large. We fixed this in #1295 by changing some int variables to long long. Likewise, I suspect one of our kernels does not work with 1M input tokens and silently reads/writes an invalid memory address.
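
A back-of-the-envelope illustration of how such an overflow can occur (the numbers are illustrative, not taken from the kernel):

# Illustrative only: a flat element offset of the form num_tokens * hidden_size
# no longer fits in a signed 32-bit int near 1M tokens for typical hidden sizes.
INT32_MAX = 2**31 - 1

hidden_size = 4096                 # e.g. Llama-2-7B
num_tokens = 16 * 65536            # the ~1M default inferred from the config
offset = num_tokens * hidden_size  # 4_294_967_296
print(offset > INT32_MAX)          # True: a 32-bit index would wrap around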

@Yard1 (Collaborator, Author) commented Oct 15, 2023

Yes, I think the default input length derived from the transformers config is way too big. I got it to work successfully with 24K context on an A100.

@WoosukKwon mentioned this pull request Nov 2, 2023
@WoosukKwon (Collaborator) left a review comment

It seems the configs of YaRN models break _get_and_verify_max_len:

def _get_and_verify_max_len(

For example, vLLM infers that the maximum length of NousResearch/Yarn-Llama-2-7b-64k is 1M rather than 64K because the max_position_embeddings in the model config is 64K and the rope scaling factor is 16.

On the other hand, for the NousResearch/Yarn-Mistral-7b-128k model, vLLM regards the maximum length as 512K because max_position_embeddings in its config is 32K (which is the original model's maximum length after 4x RoPE scaling).

In conclusion, I feel we need separate logic for YaRN models in the _get_and_verify_max_len function. WDYT?
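
A hedged sketch of such a special case (field names follow the YaRN configs quoted earlier in this thread; this is not the merged implementation):

def _derive_max_len_sketch(hf_config) -> int:
    # Sketch only: for YaRN configs, derive the maximum length from the
    # original length times the scaling factor instead of scaling
    # max_position_embeddings, which the published YaRN configs fill in
    # inconsistently.
    derived_max_len = hf_config.max_position_embeddings
    rope_scaling = getattr(hf_config, "rope_scaling", None)
    if rope_scaling is not None:
        if rope_scaling["type"] == "yarn":
            derived_max_len = (rope_scaling["original_max_position_embeddings"]
                               * rope_scaling["factor"])
        else:
            derived_max_len *= rope_scaling["factor"]
    return int(derived_max_len)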

Review thread on vllm/model_executor/layers/attention.py (outdated, resolved)
@Yard1 (Collaborator, Author) commented Nov 2, 2023

@WoosukKwon Updated, PTAL!

@Yard1 Yard1 requested a review from WoosukKwon November 2, 2023 22:43
@Yard1 (Collaborator, Author) commented Nov 2, 2023

I discovered a correctness error (the probability tensor getting a non-finite value); marking as draft until the investigation is complete.

@Yard1 Yard1 marked this pull request as draft November 2, 2023 23:03
@Yard1 (Collaborator, Author) commented Nov 2, 2023

Fixed; an int in the kernel was overflowing.

@Yard1 Yard1 marked this pull request as ready for review November 2, 2023 23:25
Review thread on csrc/pos_encoding_kernels.cu (outdated, resolved)
@casper-hansen (Contributor) commented

Hi @Yard1, just letting you know that there is now a Mistral YaRN model out. It might be a good idea to test whether this PR works with it.
https://huggingface.co/NousResearch/Yarn-Mistral-7b-64k

@Yard1 (Collaborator, Author) commented Nov 3, 2023

Good catch with the activation kernel @WoosukKwon !

@WoosukKwon (Collaborator) left a review comment

@Yard1 LGTM. Thanks for submitting the PR! FYI, I also fixed the same kind of overflow error in activation_kernels.cu.

@WoosukKwon WoosukKwon merged commit 9f669a9 into vllm-project:main Nov 3, 2023
2 checks passed
@Yard1 Yard1 deleted the yarn_support branch November 3, 2023 21:13
@xiechengmude commented Nov 5, 2023

BTW, what are the minimal resources to deploy this 64k model?

Would vLLM support an offload method in the future?

hongxiayang pushed a commit to hongxiayang/vllm that referenced this pull request Feb 13, 2024
sjchoi1 pushed a commit to casys-kaist-internal/vllm that referenced this pull request May 7, 2024
Labels: new model

Successfully merging this pull request may close these issues:
Support YaRN models (RoFormer implementation in rotary_embedding kernel)

6 participants