Fix logprob in the overlapped mode #1795

Merged (6 commits) on Oct 25, 2024
README.md (2 additions, 2 deletions)

@@ -60,7 +60,7 @@ pip install "sglang[all]"
 pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.4/
 ```

-**Important: Please check the [FlashInfer installation doc](https://docs.flashinfer.ai/installation.html) to install the proper version according to your PyTorch and CUDA versions.**
+Note: Please check the [FlashInfer installation doc](https://docs.flashinfer.ai/installation.html) to install the proper version according to your PyTorch and CUDA versions.

 ### Method 2: From source
 ```
@@ -75,7 +75,7 @@ pip install -e "python[all]"
 pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.4/
 ```

-**Important: Please check the [FlashInfer installation doc](https://docs.flashinfer.ai/installation.html) to install the proper version according to your PyTorch and CUDA versions.**
+Note: Please check the [FlashInfer installation doc](https://docs.flashinfer.ai/installation.html) to install the proper version according to your PyTorch and CUDA versions.

 ### Method 3: Using docker
 The docker images are available on Docker Hub as [lmsysorg/sglang](https://hub.docker.com/r/lmsysorg/sglang/tags), built from [Dockerfile](https://github.com/sgl-project/sglang/tree/main/docker).
docs/en/install.md (2 additions, 2 deletions)

@@ -11,7 +11,7 @@ pip install "sglang[all]"
 pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.4/
 ```

-**Important: Please check the [FlashInfer installation doc](https://docs.flashinfer.ai/installation.html) to install the proper version according to your PyTorch and CUDA versions.**
+Note: Please check the [FlashInfer installation doc](https://docs.flashinfer.ai/installation.html) to install the proper version according to your PyTorch and CUDA versions.

 ### Method 2: From source
 ```
@@ -26,7 +26,7 @@ pip install -e "python[all]"
 pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.4/
 ```

-**Important: Please check the [FlashInfer installation doc](https://docs.flashinfer.ai/installation.html) to install the proper version according to your PyTorch and CUDA versions.**
+Note: Please check the [FlashInfer installation doc](https://docs.flashinfer.ai/installation.html) to install the proper version according to your PyTorch and CUDA versions.

 ### Method 3: Using docker
 The docker images are available on Docker Hub as [lmsysorg/sglang](https://hub.docker.com/r/lmsysorg/sglang/tags), built from [Dockerfile](https://github.com/sgl-project/sglang/tree/main/docker).
python/sglang/srt/layers/logits_processor.py (5 additions, 5 deletions)

@@ -33,17 +33,17 @@ class LogitsProcessorOutput:
     # The logits of the next tokens. shape: [#seq, vocab_size]
     next_token_logits: torch.Tensor
     # The logprobs of the next tokens. shape: [#seq, vocab_size]
-    next_token_logprobs: torch.Tensor
+    next_token_logprobs: torch.Tensor = None

     # The normlaized logprobs of prompts. shape: [#seq]
-    normalized_prompt_logprobs: torch.Tensor
+    normalized_prompt_logprobs: torch.Tensor = None
     # The logprobs of input tokens. shape: [#token, vocab_size]
-    input_token_logprobs: torch.Tensor
+    input_token_logprobs: torch.Tensor = None

     # The logprob and id of the top-k tokens in input positions. shape [#seq, #token, k] of Tuple(logprob, token_id)
-    input_top_logprobs: List
+    input_top_logprobs: List = None
     # The logprob and id of the top-k tokens in output positions. shape [#seq, #token, k] of Tuple(logprob, token_id)
-    output_top_logprobs: List
+    output_top_logprobs: List = None


 @dataclasses.dataclass
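Giving every logprob field a `None` default lets call sites that only have logits (for example, the CUDA-graph replay path further down when `return_logprob` is off) construct the dataclass without passing five placeholder arguments. A minimal sketch of the pattern with a stand-in class, not the real `LogitsProcessorOutput`:

```python
import dataclasses
from typing import List, Optional

import torch


@dataclasses.dataclass
class DemoLogitsOutput:  # stand-in for LogitsProcessorOutput
    # Always produced by a forward pass.
    next_token_logits: torch.Tensor
    # Only filled in when the request asks for logprobs.
    next_token_logprobs: Optional[torch.Tensor] = None
    normalized_prompt_logprobs: Optional[torch.Tensor] = None
    input_token_logprobs: Optional[torch.Tensor] = None
    input_top_logprobs: Optional[List] = None
    output_top_logprobs: Optional[List] = None


# A logits-only producer no longer needs explicit None placeholders:
out = DemoLogitsOutput(next_token_logits=torch.randn(4, 32000))
assert out.next_token_logprobs is None
```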
python/sglang/srt/managers/scheduler.py (1 addition, 0 deletions)

@@ -833,6 +833,7 @@ def process_batch_result_decode(self, batch: ScheduleBatch, result):

         if self.enable_overlap:
             logits_output, next_token_ids = self.tp_worker.resulve_batch_result(bid)
+            next_token_logprobs = logits_output.next_token_logprobs
         else:
             # Move next_token_ids and logprobs to cpu
             if batch.return_logprob:
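The added line relies on the overlap worker (next file) having already selected the logprob of each sampled token, copied it to the CPU, and converted it to a Python list. A toy illustration with hypothetical data, not the real scheduler code:

```python
# Toy data standing in for the resolved logits_output in overlap mode:
# both fields are plain Python lists, one entry per running request.
next_token_ids = [17, 4091, 88]              # sampled token id per request
next_token_logprobs = [-0.12, -2.30, -0.87]  # logprob of exactly that token

for req_idx, (token_id, logprob) in enumerate(zip(next_token_ids, next_token_logprobs)):
    # No .cpu() or .tolist() here; the worker's copy thread already did that.
    print(f"request {req_idx}: token_id={token_id}, logprob={logprob:.2f}")
```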
python/sglang/srt/managers/tp_worker_overlap_thread.py (38 additions, 3 deletions)

@@ -103,6 +103,8 @@ def forward_thread_func_(self):
         while True:
             self.has_inflight_batch = False
             model_worker_batch, future_token_ids_ct = self.input_queue.get()
+            if not model_worker_batch:
+                break
             self.has_inflight_batch = True
             self.launch_event = threading.Event()

@@ -122,19 +124,48 @@
             ] = next_token_ids

             # Copy results to the CPU
+            if model_worker_batch.return_logprob:
+                logits_output.next_token_logprobs = logits_output.next_token_logprobs[
+                    torch.arange(len(next_token_ids), device=self.device),
+                    next_token_ids,
+                ].to("cpu", non_blocking=True)
+                if logits_output.input_token_logprobs is not None:
+                    logits_output.input_token_logprobs = (
+                        logits_output.input_token_logprobs.to("cpu", non_blocking=True)
+                    )
+                    logits_output.normalized_prompt_logprobs = (
+                        logits_output.normalized_prompt_logprobs.to(
+                            "cpu", non_blocking=True
+                        )
+                    )
             next_token_ids = next_token_ids.to("cpu", non_blocking=True)
             copy_event = torch.cuda.Event(blocking=True)
             copy_event.record()

             self.launch_event.set()
-            self.copy_queue.put((copy_event, next_token_ids))
+            self.copy_queue.put((copy_event, logits_output, next_token_ids))

     def copy_thread_func(self):
         while True:
-            copy_event, next_token_ids = self.copy_queue.get()
+            copy_event, logits_output, next_token_ids = self.copy_queue.get()
+            if not copy_event:
+                break
             while not copy_event.query():
                 time.sleep(1e-5)
-            self.output_queue.put((None, next_token_ids.tolist()))
+
+            if logits_output.next_token_logprobs is not None:
+                logits_output.next_token_logprobs = (
+                    logits_output.next_token_logprobs.tolist()
+                )
+                if logits_output.input_token_logprobs is not None:
+                    logits_output.input_token_logprobs = (
+                        logits_output.input_token_logprobs.tolist()
+                    )
+                    logits_output.normalized_prompt_logprobs = (
+                        logits_output.normalized_prompt_logprobs.tolist()
+                    )
+
+            self.output_queue.put((logits_output, next_token_ids.tolist()))

     def resulve_batch_result(self, bid: int):
         logits_output, next_token_ids = self.output_queue.get()

@@ -172,3 +203,7 @@ def update_weights(self, recv_req: UpdateWeightReqInput):
             recv_req.model_path, recv_req.load_format
         )
         return success, message
+
+    def __delete__(self):
+        self.input_queue.put((None, None))
+        self.copy_queue.put((None, None, None))
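The pattern in this file: the forward thread gathers the logprob of each sampled token, launches non-blocking GPU-to-CPU copies, and records a CUDA event; a separate copy thread polls that event, converts the finished tensors to Python lists, and pushes them to the output queue; `None` sentinels shut both threads down. A self-contained, simplified sketch of that pattern (hypothetical class, assumes a CUDA device; not the real `TpModelWorkerClient`):

```python
import queue
import threading
import time

import torch


class OverlapCopyDemo:
    """Simplified stand-in for the forward-thread / copy-thread handoff."""

    def __init__(self, device: str = "cuda"):
        self.device = device
        self.copy_queue: queue.Queue = queue.Queue()
        self.output_queue: queue.Queue = queue.Queue()
        self.copy_thread = threading.Thread(target=self.copy_thread_func, daemon=True)
        self.copy_thread.start()

    def forward_step(self, logits, next_token_ids, return_logprob):
        """Forward-thread side: launch async D2H copies, record an event, hand off."""
        next_token_logprobs = None
        if return_logprob:
            # Keep only the logprob of the token that was actually sampled for
            # each sequence, then start a non-blocking copy to the host.
            logprobs = torch.log_softmax(logits, dim=-1)
            next_token_logprobs = logprobs[
                torch.arange(len(next_token_ids), device=self.device),
                next_token_ids,
            ].to("cpu", non_blocking=True)
        next_token_ids = next_token_ids.to("cpu", non_blocking=True)

        # Recorded after the copies are enqueued, so the event completes only
        # once the host-side tensors are safe to read.
        copy_event = torch.cuda.Event(blocking=True)
        copy_event.record()
        self.copy_queue.put((copy_event, next_token_logprobs, next_token_ids))

    def copy_thread_func(self):
        while True:
            copy_event, next_token_logprobs, next_token_ids = self.copy_queue.get()
            if copy_event is None:  # sentinel: shut the thread down
                break
            while not copy_event.query():  # poll instead of blocking the GPU
                time.sleep(1e-5)
            self.output_queue.put((
                None if next_token_logprobs is None else next_token_logprobs.tolist(),
                next_token_ids.tolist(),
            ))

    def shutdown(self):
        self.copy_queue.put((None, None, None))
        self.copy_thread.join()


if __name__ == "__main__":
    demo = OverlapCopyDemo()
    logits = torch.randn(2, 128, device="cuda")
    sampled = torch.tensor([3, 7], device="cuda")
    demo.forward_step(logits, sampled, return_logprob=True)
    print(demo.output_queue.get())  # ([logprob_of_3, logprob_of_7], [3, 7])
    demo.shutdown()
```

Polling the event with a short sleep keeps the GPU stream free while the copy drains, which is the point of running the copy on its own thread.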
python/sglang/srt/model_executor/cuda_graph_runner.py (14 additions, 16 deletions)

@@ -263,7 +263,8 @@ def run_once():
                 positions=clamp_position(seq_lens),
                 mrope_positions=mrope_positions,
             )
-            return forward(input_ids, forward_batch.positions, forward_batch)
+            logits_output = forward(input_ids, forward_batch.positions, forward_batch)
+            return logits_output.next_token_logits

         for _ in range(2):
             torch.cuda.synchronize()

@@ -318,23 +319,16 @@ def replay(self, forward_batch: ForwardBatch):

         # Replay
         self.graphs[bs].replay()
-        logits_output = self.output_buffers[bs]
-
-        # Unpad
-        if bs != raw_bs:
-            logits_output = LogitsProcessorOutput(
-                next_token_logits=logits_output.next_token_logits[:raw_bs],
-                next_token_logprobs=None,
-                normalized_prompt_logprobs=None,
-                input_token_logprobs=None,
-                input_top_logprobs=None,
-                output_top_logprobs=None,
-            )
+        next_token_logits = self.output_buffers[bs][:raw_bs]

         # Extract logprobs
         if forward_batch.return_logprob:
-            logits_output.next_token_logprobs = torch.nn.functional.log_softmax(
-                logits_output.next_token_logits, dim=-1
+            next_token_logprobs = torch.nn.functional.log_softmax(
+                next_token_logits, dim=-1
             )
+            logits_output = LogitsProcessorOutput(
+                next_token_logits=next_token_logits,
+                next_token_logprobs=next_token_logprobs,
+            )
             return_top_logprob = any(x > 0 for x in forward_batch.top_logprobs_nums)
             if return_top_logprob:
@@ -343,7 +337,11 @@ def replay(self, forward_batch: ForwardBatch):
                     top_logprobs_nums=forward_batch.top_logprobs_nums,
                 )
                 logits_output.output_top_logprobs = LogitsProcessor.get_top_logprobs(
-                    logits_output.next_token_logprobs, logits_metadata
+                    next_token_logprobs, logits_metadata
                 )[1]
+        else:
+            logits_output = LogitsProcessorOutput(
+                next_token_logits=next_token_logits,
+            )

         return logits_output
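The replay path now slices the padded output buffer down to the real batch size and builds `LogitsProcessorOutput` lazily, computing `log_softmax` over the vocabulary only when logprobs were requested. A rough sketch of that logic as a standalone helper, with hypothetical names rather than the real `CudaGraphRunner.replay`:

```python
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class DemoReplayOutput:  # stand-in for LogitsProcessorOutput
    next_token_logits: torch.Tensor
    next_token_logprobs: Optional[torch.Tensor] = None


def extract_replay_output(output_buffer, raw_bs, return_logprob):
    # Unpad: the graph was captured for a padded batch size, so keep only the
    # first `raw_bs` rows of the reused output buffer.
    next_token_logits = output_buffer[:raw_bs]

    if return_logprob:
        # Full-vocabulary logprobs are computed only when they were requested.
        next_token_logprobs = torch.log_softmax(next_token_logits, dim=-1)
        return DemoReplayOutput(next_token_logits, next_token_logprobs)

    # Logits-only result; the optional field keeps its None default
    # (see the LogitsProcessorOutput change above).
    return DemoReplayOutput(next_token_logits)


# Padded graph batch of 8, real batch of 3, toy vocab of 1000.
out = extract_replay_output(torch.randn(8, 1000), raw_bs=3, return_logprob=True)
assert out.next_token_logprobs.shape == (3, 1000)
```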
test/srt/run_suite.py (0 additions, 1 deletion)

@@ -19,7 +19,6 @@
         "test_openai_server.py",
         "test_overlap_schedule.py",
         "test_pytorch_sampling_backend.py",
-        "test_radix_attention.py",
         "test_retract_decode.py",
         "test_server_args.py",
         "test_skip_tokenizer_init.py",