
Speculative decoding with EAGLE2 #1498

Closed · wants to merge 49 commits

Conversation

@yukavio (Collaborator) commented Sep 24, 2024

Motivation

Accelerate model inference with speculative decoding (EAGLE2).

Modifications

Details will be provided soon.

Checklist

  • Format your code according to the Contributor Guide.
  • Add unit tests as outlined in the Contributor Guide.
  • Update documentation as needed, including docstrings or example tutorials.

@Qiubo1 commented Sep 26, 2024

Hello, does this code support speculative decoding for multiple requests?

@merrymercy mentioned this pull request Oct 6, 2024
@fengyang95 commented Oct 9, 2024

Hi @yukavio, is there any recent progress or plan for this? Do you plan to support DeepSeek-V2?

@yukavio (Collaborator, Author) commented Oct 11, 2024

Hello, does this code support speculative decoding for multiple requests?

Yes, I will support it.

@yukavio (Collaborator, Author) commented Oct 11, 2024

Hi @yukavio, is there any recent progress or plan for this? Do you plan to support DeepSeek-V2?

I have implemented the draft and verify stages and tested them on a single request. I am migrating my code to the main branch because it contains some significant changes to the controller and worker that are important for my implementation.
I do not plan to support DeepSeek-V2, because there is no open-source EAGLE2 draft model for DeepSeek-V2 to test with.
For now, I plan to implement this feature for Llama.

My plan:
  • Migrate the code and test it: 1-2 days.
  • Implement the remaining code for single-request speculative decoding: half a week to one week.
  • Implement the remaining code for batched speculative decoding: one to two weeks.

@Qiubo1 commented Oct 16, 2024

Thanks, yukavio. I have some suggestions for this PR:
  1. Furthermore, to support more models, I think we should pop the EAGLE head from draft_extend_input_queue so we don't modify the original Llama model file.
  2. I don't understand why we need so many SpecInfoPipline queues. Speculation only happens in the decoding stage, so at the very least we shouldn't need draft_extend_input_queue.

@Qiubo1 commented Oct 16, 2024

I also have another question: in this PR, model_runner.py initializes the KV cache twice in different TP workers, which results in GPU OOM. Could we merge the draft and target KV caches to increase GPU utilization?

@yukavio (Collaborator, Author) commented Oct 16, 2024

I also have another question: in this PR, model_runner.py initializes the KV cache twice in different TP workers, which results in GPU OOM. Could we merge the draft and target KV caches to increase GPU utilization?

I have migrated the code to another branch: https://github.com/yukavio/sglang/tree/new_spec_infer, and I will update this PR with it later. In the new implementation, I run the draft worker and the target model worker in one process instead of using many queues in SpecInfoPipline to communicate between the draft and target processes.

For memory management, I've fixed this bug in the new branch so that it won't raise an error during testing. It may not be very efficient yet, and I will improve it after I finish the remaining work in the plan.
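Purely as an illustration of the memory question above, and not a claim about what the new branch actually does: one generic way to keep a draft runner and a target runner in the same process from each claiming the full KV-cache budget is to split a single memory fraction between them. All names and numbers below are hypothetical.

from dataclasses import dataclass

@dataclass
class RunnerBudget:
    model_name: str
    mem_fraction_static: float  # share of GPU memory this runner may use for its KV cache

def split_budget(total_fraction: float = 0.85, draft_share: float = 0.15):
    # Give the small EAGLE draft model a slice of the overall budget and the
    # target model the rest, instead of letting each runner claim the full
    # fraction (the double-allocation/OOM problem discussed above).
    draft = RunnerBudget("eagle-draft", total_fraction * draft_share)
    target = RunnerBudget("target-llama", total_fraction * (1.0 - draft_share))
    return draft, target

if __name__ == "__main__":
    draft, target = split_budget()
    print(draft)
    print(target)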

@zhyncs (Member) commented Oct 21, 2024

@yukavio Hi yukavio, SGLang has recently undergone some refactoring work. You need to merge the latest main to resolve the corresponding conflicts. Thanks!

@fengyang95

@yukavio Hi, when is this PR expected to be merged? I've trained a draft model and am eager to try it out.

@yukavio (Collaborator, Author) commented Oct 21, 2024

@yukavio Hi yukavio, SGLang has recently undergone some refactoring work. You need to merge the latest main to resolve the corresponding conflicts. Thanks!

OK. I am fixing some bugs in batch inference now and will update the code against the main branch once they are fixed. Personally, I think the updated code can serve as the first version, and the community could review that version of the implementation.

@yukavio (Collaborator, Author) commented Oct 21, 2024

@yukavio Hi, when is this PR expected to be merged? I've trained a draft model and am eager to try it out.

If all goes well, I will finish the first version of the development this week. When it merges into the main branch depends on community review and feedback.

@fengyang95

@yukavio Is CLI startup not supported currently? I encountered this error:

File "/opt/tiger/sglang/python/sglang/srt/server_args.py", line 613, in <dictcomp>
return cls(**{attr: getattr(args, attr) for attr in attrs})
^^^^^^^^^^^^^^^^^^^
AttributeError: 'Namespace' object has no attribute 'draft_runner_cache_size'
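For context on this traceback: it typically means a new ServerArgs field (here draft_runner_cache_size) was added without a matching argparse flag, so the parsed Namespace never carries that attribute. The sketch below is a minimal, self-contained illustration of that pattern, modeled only on the dict comprehension shown in the traceback; the class, flag name, and default are assumptions, not the PR's actual code.

import argparse
from dataclasses import dataclass, fields

@dataclass
class ServerArgsSketch:
    # Hypothetical field corresponding to the missing attribute in the traceback.
    draft_runner_cache_size: int = 0

    @staticmethod
    def add_cli_args(parser: argparse.ArgumentParser) -> None:
        # Every dataclass field needs a matching CLI flag; otherwise
        # from_cli_args fails with AttributeError on the parsed Namespace.
        parser.add_argument(
            "--draft-runner-cache-size",
            type=int,
            default=ServerArgsSketch.draft_runner_cache_size,
            help="Illustrative flag: cache size reserved for the draft runner.",
        )

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace) -> "ServerArgsSketch":
        # Mirrors the dict comprehension shown in the traceback.
        attrs = [f.name for f in fields(cls)]
        return cls(**{attr: getattr(args, attr) for attr in attrs})

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    ServerArgsSketch.add_cli_args(parser)
    print(ServerArgsSketch.from_cli_args(parser.parse_args([])))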

@zhyncs self-assigned this Nov 11, 2024
@zhyncs changed the title from "[WIP] Spec infer with EAGLE2" to "Speculative decoding with EAGLE2" Nov 14, 2024
@yukavio (Collaborator, Author) commented Nov 14, 2024

A note: cutex should be added to SGLang's dependency list after review.

@merrymercy (Contributor)

@yukavio Can you resolve the conflicts?

@yukavio (Collaborator, Author) commented Nov 18, 2024

@yukavio Can you resolve the conflicts?

Fixed. Could you please help me trigger the CI?

@yukavio (Collaborator, Author) commented Nov 18, 2024

The CI has failed due to a timeout.

@merrymercy (Contributor) left a comment
I did an initial code style review. I will follow up with a more careful logic review. We want to merge this soon.

Some guidelines:

  1. Split the big PR into smaller PRs. Merge the small, not-yet-reachable pieces first (e.g., introducing the two new forward modes).
  2. Minimize the changes to the common scheduler code and move most things into self-contained EAGLE-specific files.

Sorry for the frequent changes on the main branch. We did a lot of refactoring to enable the overlap scheduler by default. That refactor has been finished (7d671e4) so we do not expect big changes after that. This PR is now our top priority.

examples/runtime/engine/EAGLE_offline_batch_inference.py (outdated, resolved)
python/sglang/srt/layers/attention/flashinfer_utils.py (outdated, resolved)
python/sglang/srt/layers/logits_processor.py (outdated, resolved)
python/sglang/srt/server_args.py (outdated, resolved)
python/sglang/srt/server_args.py (outdated, resolved)
python/sglang/srt/models/llama_eagle.py (outdated, resolved)
kernels = cutex.SourceModule(
"""
//cuda
__global__ void build_tree(Tensor<long, 2> parent_list, Tensor<long, 2> selected_index, Tensor<int, 1> verified_seq_len,
Is it possible to do this with Triton?
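For reference, a minimal Triton kernel skeleton is sketched below. It implements a generic gather, not the build_tree logic of this PR, and every name in it is illustrative; it only shows the program-id/block/mask pattern that a port away from cutex would follow.

import torch
import triton
import triton.language as tl

@triton.jit
def gather_kernel(src_ptr, idx_ptr, out_ptr, n_elements, BLOCK: tl.constexpr):
    # One program instance handles one BLOCK-sized chunk of the output.
    pid = tl.program_id(axis=0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n_elements
    # Load the gather indices, then the gathered values, masking off
    # out-of-range lanes.
    idx = tl.load(idx_ptr + offs, mask=mask, other=0)
    vals = tl.load(src_ptr + idx, mask=mask, other=0)
    tl.store(out_ptr + offs, vals, mask=mask)

def gather(src: torch.Tensor, index: torch.Tensor) -> torch.Tensor:
    # Assumes contiguous 1-D CUDA tensors; a toy stand-in, not build_tree.
    out = torch.empty_like(index, dtype=src.dtype)
    n = index.numel()
    grid = (triton.cdiv(n, 1024),)
    gather_kernel[grid](src, index, out, n, BLOCK=1024)
    return out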

python/sglang/srt/speculative/eagle_utils.py (outdated, resolved)
positions = forward_batch.positions
if positions is None:
    positions = clamp_position(forward_batch.seq_lens)
self.positions[:raw_num_token].copy_(positions)
@jjjjohnson (Contributor) commented Nov 22, 2024
@yukavio Do we need to assign spec_info.custom_mask to self.cuda_graph_custom_mask before replay? It looks like self.cuda_graph_custom_mask is not used during CUDA graph replay().

Comment on lines +668 to +670
logits_output.next_token_logits = logits_output.next_token_logits_bak[
    accept_index
]
Suggested change:
-    logits_output.next_token_logits = logits_output.next_token_logits_bak[
-        accept_index
-    ]
+    self.next_token_logits_back = next_token_logits_back
+    logits_output.next_token_logits = logits_output.next_token_logits[
+        accept_index
+    ]

Comment on lines +32 to +35
encoder_lens: torch.Tensor = None,
spec_info: "SpecInput" = None,
is_draft_runner: bool = False,
forward_batch: ForwardBatch = None,
Style: if an argument can be None, we should use "Optional[xxx]" instead of "xxx".

Suggested change:
-    encoder_lens: torch.Tensor = None,
-    spec_info: "SpecInput" = None,
-    is_draft_runner: bool = False,
-    forward_batch: ForwardBatch = None,
+    encoder_lens: Optional[torch.Tensor] = None,
+    spec_info: Optional["SpecInput"] = None,
+    is_draft_runner: bool = False,
+    forward_batch: Optional[ForwardBatch] = None,

req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor,
seq_lens_sum: int,
encoder_lens: Optional[torch.Tensor] = None,
encoder_lens=None,
forward_batch=None,
Correct the type annotation; use Optional.

@@ -130,8 +135,37 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):
forward_batch.seq_lens_sum,
decode_wrappers=None,
encoder_lens=forward_batch.encoder_lens,
forward_batch=forward_batch,
I do not think we need this for `indices_updater_decode`.

@@ -164,52 +199,102 @@ def init_cuda_graph_state(self, max_bs: int):
cuda_graph_kv_indices.clone() for _ in range(self.num_wrappers - 1)
]

self.cuda_graph_custom_mask = torch.zeros(
(max_bs * (self.max_context_len + 7) // 8),
Why is the rounding needed here?

paged_kv_indices_buffer=self.cuda_graph_kv_indices[i],
paged_kv_last_page_len_buffer=self.kv_last_page_len[:bs],
# speculative decoding verify stage
if spec_info is not None and not is_draft_runner:
We can use forward_mode.is_target_verify() as the condition.

req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor,
encoder_lens: torch.Tensor = None,
spec_info: SpecInput = None,
is_draft_runner: bool = False,
is_draft_runner does not seem useful.

# speculative decoding verify stage
if spec_info is not None and not is_draft_runner:
    for i in range(self.num_wrappers):
        decode_wrappers.append(
Call it prefill_wrappers.

Comment on lines +284 to +285
encoder_lens=None,
forward_batch=None,
Keep the type annotations.

hidden_states: Optional[torch.Tensor] = None
# backup of next_token_logits when use cuda graph
# id(next_token_logits_bak) == id(next_token_logits)
next_token_logits_bak: Optional[torch.Tensor] = None
we do not need this.

@@ -96,7 +96,7 @@ def __init__(self, server_args, port_args) -> None:
else:
# This port is checked free in PortArgs.init_new.
# We hold it first so that the next dp worker gets a different port
sockets.append(bind_port(tmp_port_args.nccl_port))
sockets.append(bind_port(tmp_port_args.nccl_port[0]))
Use another variable instead of making it a list.
