
[RFC]: Classifier-Free Guidance #5825

Closed

Vermeille opened this issue Jun 25, 2024 · 4 comments

Comments

@Vermeille

Motivation.

I am one of the authors of the paper Stay On Topic with Classifier-Free Guidance ( https://openreview.net/forum?id=RiM3cl9MdK&noteId=s1BXLL1YZD ), which was selected as an ICML'24 Spotlight paper. CFG is a sampling technique that allows LLMs to follow the prompt more closely, at the cost of two forward passes per token and two KV caches. CFG brings non-trivial improvements across standard benchmarks.

I would be extremely interested in having CFG implemented in vLLM. If possible, I would like some guidance on the vLLM code base.

Proposed Change.

CFG contrasts the next-token logits between two different prompts (a "positive prompt" a, and a "negative" or "unconditional" prompt b).

Here is the pseudo-algorithm:

while we sample:
    logits_a = log_softmax(model(prompt_a))  # conditional / positive prompt
    logits_b = log_softmax(model(prompt_b))  # unconditional / negative prompt
    # push the distribution away from b and towards a; cfg_scale > 1 strengthens the prompt
    logits = logits_b + cfg_scale * (logits_a - logits_b)
    next_token = sample_from(logits)
    # both sequences are extended with the same token
    prompt_a.append(next_token)
    prompt_b.append(next_token)

As you can see, an efficient implementation needs two concurrent KV caches. I looked at how Speculative Decoding is implemented, but it is quite complex, more than CFG needs.
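To make the two-KV-cache requirement concrete, here is a minimal, self-contained sketch of the same loop outside vLLM, using Hugging Face transformers with one past_key_values cache per prompt. The model name, prompts, generation length, and cfg_scale are arbitrary choices for illustration only.

import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")            # placeholder model for illustration
model = AutoModelForCausalLM.from_pretrained("gpt2").eval()

cfg_scale = 1.5
ids_a = tok("Question: what is CFG? Answer:", return_tensors="pt").input_ids  # positive prompt
ids_b = tok("Answer:", return_tensors="pt").input_ids                         # negative / unconditional prompt

past_a = past_b = None   # two independent KV caches
generated = []

with torch.no_grad():
    for _ in range(32):
        out_a = model(ids_a, past_key_values=past_a, use_cache=True)
        out_b = model(ids_b, past_key_values=past_b, use_cache=True)
        past_a, past_b = out_a.past_key_values, out_b.past_key_values

        logp_a = F.log_softmax(out_a.logits[:, -1, :], dim=-1)
        logp_b = F.log_softmax(out_b.logits[:, -1, :], dim=-1)
        logits = logp_b + cfg_scale * (logp_a - logp_b)

        next_token = torch.multinomial(logits.softmax(dim=-1), num_samples=1)
        generated.append(next_token.item())

        # with KV caching, only the newly sampled token is fed on the next step
        ids_a = ids_b = next_token

print(tok.decode(generated))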

Feedback Period.

No response

CC List.

No response

Any Other Things.

I am willing to implement it myself given enough guidance, as this looks like a non-trivial thing to implement. Something similar to, or reusing bits of, Speculative Decoding might work, but that code is itself non-trivial.

Vermeille added the RFC label Jun 25, 2024
@Vermeille
Author

Up

@cadedaniel
Collaborator

cadedaniel commented Jul 9, 2024

Hi @Vermeille. Great work.

For implementation in vLLM, this can be done at a similar layer to Speculative Decoding:

LLMEngine
    CFGWorker  < logic which calls the underlying worker twice, does logit math, samples >
        Worker (2x)

The primary benefit of this design is that you can manage two block tables in the existing LLMEngine and scheduler without any modification. Speculative decoding (with a draft model) does this by splitting the KV cache space evenly into two equally-sized regions [1], so the same block table works for both models. You can prototype this relatively straightforwardly; the only major missing piece is that one of the Workers must not load weights (i.e. weight loading is shared with the other worker).

Alternatively, you can use a single worker and modify the block tables with a constant offset so that each sequence gets an independent KV cache.
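For illustration only (the block ids and counts below are made up, not actual vLLM scheduler output), the constant-offset idea amounts to mapping the negative-prompt sequence into a reserved upper half of the physical KV blocks:

# hypothetical numbers, illustrating the constant-offset idea only
num_gpu_blocks = 1024
offset = num_gpu_blocks // 2                  # reserve the upper half for the negative branch

block_table_pos = [0, 1, 2]                   # physical blocks of the positive-prompt sequence
block_table_neg = [b + offset for b in block_table_pos]   # same layout, independent KV region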

A secondary benefit of this design is hardware agnosticism: your implementation can work with non-NVIDIA, non-AMD hardware backends.

[1]

def determine_num_available_blocks(self) -> Tuple[int, int]:
    """Determine the number of cache blocks to use.

    This is done by profiling the scorer model (which is typically the
    larger of the two). Then the total memory which would be used by the
    scorer cache is divided evenly between the proposer and scorer model KV,
    such that the number of blocks is equal in both KV caches.
    """
    num_gpu_blocks, num_cpu_blocks = (
        self.scorer_worker.determine_num_available_blocks())

    scorer_cache_block_size_bytes = (
        self.scorer_worker.get_cache_block_size_bytes())
    proposer_cache_block_size_bytes = (
        self.proposer_worker.get_cache_block_size_bytes())

    new_num_gpu_blocks = split_num_cache_blocks_evenly(
        scorer_cache_block_size_bytes, proposer_cache_block_size_bytes,
        num_gpu_blocks)
    return new_num_gpu_blocks, num_cpu_blocks
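
For concreteness, here is a rough sketch of what the CFGWorker layer described above could look like. The names (CFGWorker, execute_model, the batch arguments) are hypothetical and do not correspond to existing vLLM interfaces; the sketch only illustrates the "call the underlying worker twice, do logit math, sample" step.

import torch
import torch.nn.functional as F


class CFGWorker:
    """Hypothetical worker that runs the model twice per step (positive and
    negative prompt), combines the logits with CFG, and samples."""

    def __init__(self, worker_pos, worker_neg, cfg_scale: float):
        self.worker_pos = worker_pos   # loads the model weights
        self.worker_neg = worker_neg   # shares the weights, has its own KV cache region
        self.cfg_scale = cfg_scale

    def execute_model(self, batch_pos, batch_neg) -> torch.Tensor:
        # two forward passes, one per prompt variant, each using its own block table
        logits_pos = self.worker_pos.execute_model(batch_pos)
        logits_neg = self.worker_neg.execute_model(batch_neg)

        logp_pos = F.log_softmax(logits_pos, dim=-1)
        logp_neg = F.log_softmax(logits_neg, dim=-1)
        # CFG: push the unconditional distribution towards the conditional one
        logits = logp_neg + self.cfg_scale * (logp_pos - logp_neg)

        # sample one token per sequence; both branches are then extended with
        # the same token so their KV caches stay in sync
        next_tokens = torch.multinomial(logits.softmax(dim=-1), num_samples=1)
        return next_tokens.squeeze(-1)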


This issue has been automatically marked as stale because it has not had any activity within 90 days. It will be automatically closed if no further activity occurs within 30 days. Leave a comment if you feel this issue should remain open. Thank you!

github-actions bot added the stale label Oct 25, 2024

This issue has been automatically closed due to inactivity. Please feel free to reopen if you feel it is still relevant. Thank you!

github-actions bot closed this as not planned Nov 24, 2024