Fix logit processor bugs (#427)
merrymercy authored May 12, 2024
1 parent 7023f41 commit aee4f52
Showing 26 changed files with 166 additions and 257 deletions.
5 changes: 0 additions & 5 deletions README.md
@@ -297,7 +297,6 @@ curl http://localhost:30000/generate \
Learn more about the argument format [here](docs/sampling_params.md).

### OpenAI Compatible API

In addition, the server supports an experimental OpenAI-compatible API.

```python
@@ -386,7 +385,6 @@ python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port
Instructions for supporting a new model are [here](https://github.com/sgl-project/sglang/blob/main/docs/model_support.md).

## Benchmark And Performance

- Llama-7B on NVIDIA A10G, FP16, Tensor Parallelism=1
![llama_7b](assets/llama_7b.jpg)

@@ -410,7 +408,4 @@ https://github.com/sgl-project/sglang/issues/157
}
```

[![Paper page](https://huggingface.co/datasets/huggingface/badges/resolve/main/paper-page-md.svg)](https://huggingface.co/papers/2312.07104)


We learned from the design and reused some code of the following projects: [Guidance](https://github.com/guidance-ai/guidance), [vLLM](https://github.com/vllm-project/vllm), [LightLLM](https://github.com/ModelTC/lightllm), [FlashInfer](https://github.com/flashinfer-ai/flashinfer), [Outlines](https://github.com/outlines-dev/outlines), [LMQL](https://github.com/eth-sri/lmql).
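
The README hunk above documents the experimental OpenAI-compatible API. As context, here is a minimal client sketch, assuming an SGLang server is already listening on port 30000 and the `openai` Python package (v1.x) is installed; the model name and prompt are placeholders, not taken from this commit.

```python
import openai

# Point the standard OpenAI client at the local SGLang server
# (assumes `python -m sglang.launch_server ... --port 30000` is running).
client = openai.Client(base_url="http://127.0.0.1:30000/v1", api_key="EMPTY")

response = client.completions.create(
    model="default",                    # placeholder model identifier
    prompt="The capital of France is",
    temperature=0,
    max_tokens=16,
)
print(response.choices[0].text)
```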
2 changes: 2 additions & 0 deletions python/sglang/api.py
@@ -1,5 +1,6 @@
"""Some Public API Definitions"""

import os
import re
from typing import Callable, List, Optional, Union

@@ -31,6 +32,7 @@ def decorator(func):

def Runtime(*args, **kwargs):
# Avoid importing unnecessary dependency
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
from sglang.srt.server import Runtime

return Runtime(*args, **kwargs)
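For reference, a minimal sketch of how the `Runtime` wrapper from the hunk above is typically used; the model path is a placeholder and the surrounding calls are illustrative, not part of this commit.

```python
import sglang as sgl

# Runtime() defers the heavy server-side imports until it is actually called;
# TF_CPP_MIN_LOG_LEVEL=3 (set in the hunk above) silences TensorFlow C++ logging.
runtime = sgl.Runtime(model_path="meta-llama/Llama-2-7b-chat-hf")
sgl.set_default_backend(runtime)
# ... run sgl programs against this runtime ...
runtime.shutdown()
```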
9 changes: 5 additions & 4 deletions python/sglang/backend/anthropic.py
@@ -14,14 +14,15 @@


class Anthropic(BaseBackend):
def __init__(self, model_name):
def __init__(self, model_name, *args, **kwargs):
super().__init__()

if isinstance(anthropic, Exception):
raise anthropic

self.model_name = model_name
self.chat_template = get_chat_template("claude")
self.client = anthropic.Anthropic(*args, **kwargs)

def get_chat_template(self):
return self.chat_template
@@ -41,7 +42,7 @@ def generate(
else:
system = ""

ret = anthropic.Anthropic().messages.create(
ret = self.client.messages.create(
model=self.model_name,
system=system,
messages=messages,
@@ -66,11 +67,11 @@ def generate_stream(
else:
system = ""

with anthropic.Anthropic().messages.stream(
with self.client.messages.stream(
model=self.model_name,
system=system,
messages=messages,
**sampling_params.to_anthropic_kwargs(),
) as stream:
for text in stream.text_stream:
yield text, {}
yield text, {}
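
A hypothetical usage sketch for the constructor change above: extra positional and keyword arguments are now forwarded to `anthropic.Anthropic`, so credentials can be passed explicitly instead of relying on environment variables. The model name and key below are placeholders.

```python
import sglang as sgl

# kwargs such as api_key are forwarded to anthropic.Anthropic(...), and the
# resulting client is reused by generate() and generate_stream().
backend = sgl.Anthropic("claude-3-haiku-20240307", api_key="sk-ant-...")
sgl.set_default_backend(backend)
```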
2 changes: 1 addition & 1 deletion python/sglang/backend/openai.py
@@ -228,7 +228,7 @@ def select(
prompt_tokens.append(ret_token)

decision = choices[np.argmax(scores)]
return decision, scores, scores
return decision, scores, None, None


def openai_completion(client, retries=3, is_chat=None, prompt=None, **kwargs):
1 change: 0 additions & 1 deletion python/sglang/backend/runtime_endpoint.py
@@ -220,7 +220,6 @@ def select(
"sampling_params": {"max_new_tokens": 0},
"return_logprob": True,
"logprob_start_len": max(prompt_len - 2, 0),
"return_text_in_logprobs": True,
}
self._add_images(s, data)
res = http_request(
31 changes: 19 additions & 12 deletions python/sglang/srt/layers/logits_processor.py
@@ -42,26 +42,29 @@ def _get_top_logprobs(self, all_logprobs, input_metadata: InputMetadata):
for i in range(all_logprobs.shape[0]):
k = input_metadata.top_logprobs_nums[i]
t = all_logprobs[i].topk(k)
v_cpu = t.values.cpu().tolist()
p_cpu = t.indices.cpu().tolist()
v_cpu = t.values.tolist()
p_cpu = t.indices.tolist()
decode_top_logprobs.append(list(zip(v_cpu, p_cpu)))
return None, decode_top_logprobs
else:
prefill_top_logprobs, decode_top_logprobs = [], []
pt = 0
# NOTE: the GPU-CPU overhead can be reduced
extend_seq_lens_cpu = input_metadata.extend_seq_lens
for i in range(len(input_metadata.extend_seq_lens)):
extend_seq_lens_cpu = input_metadata.extend_seq_lens.cpu().numpy()
for i in range(len(extend_seq_lens_cpu)):
if extend_seq_lens_cpu[i] == 0:
prefill_top_logprobs.append([])
decode_top_logprobs.append([])
continue
k = input_metadata.top_logprobs_nums[i]
t = all_logprobs[pt : pt + extend_seq_lens_cpu[i]].topk(k)
vs_cpu = t.values.cpu().tolist()
ps_cpu = t.indices.cpu().tolist()
vs_cpu = t.values.tolist()
ps_cpu = t.indices.tolist()
prefill_top_logprobs.append(
[list(zip(vs_cpu[j], ps_cpu[j])) for j in range(len(vs_cpu) - 1)]
)
decode_top_logprobs.append(list(zip(vs_cpu[-1], ps_cpu[-1])))
pt += extend_seq_lens_cpu[i]
return prefill_top_logprobs, decode_top_logprobs

def forward(self, input_ids, hidden_states, weight, input_metadata: InputMetadata):
@@ -99,20 +102,24 @@ def forward(self, input_ids, hidden_states, weight, input_metadata: InputMetadat
all_logits = all_logits[:, : self.config.vocab_size]

all_logprobs = all_logits.float()
all_logits = None
del all_logits
all_logprobs[:] = torch.nn.functional.log_softmax(all_logprobs, dim=-1)

prefill_top_logprobs, decode_top_logprobs = self._get_top_logprobs(
all_logprobs, input_metadata
)
return_top_logprob = any(x > 0 for x in input_metadata.top_logprobs_nums)
if return_top_logprob:
prefill_top_logprobs, decode_top_logprobs = self._get_top_logprobs(
all_logprobs, input_metadata
)
else:
prefill_top_logprobs = decode_top_logprobs = None

if input_metadata.forward_mode == ForwardMode.DECODE:
last_logprobs = all_logprobs
return last_logits, (
None,
None,
decode_top_logprobs,
None,
decode_top_logprobs,
last_logprobs,
)
else:
@@ -131,9 +138,9 @@ def forward(self, input_ids, hidden_states, weight, input_metadata: InputMetadat
)
return last_logits, (
prefill_token_logprobs,
normalized_prompt_logprobs,
prefill_top_logprobs,
decode_top_logprobs,
normalized_prompt_logprobs,
last_logprobs,
)

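As a standalone illustration of the top-k extraction pattern in `_get_top_logprobs` above (not part of the commit; the shapes and `k` are made up), note that `Tensor.tolist()` already copies data to host memory, so it can be called on GPU tensors directly.

```python
import torch

# Illustrative only: per-position top-k (logprob, token_id) pairs.
logprobs = torch.log_softmax(torch.randn(4, 32000), dim=-1)  # [num_positions, vocab]
k = 5
t = logprobs.topk(k, dim=-1)
values, indices = t.values.tolist(), t.indices.tolist()
top_logprobs = [list(zip(values[i], indices[i])) for i in range(len(values))]
print(top_logprobs[0][:2])  # first two (logprob, token_id) pairs for position 0
```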
1 change: 0 additions & 1 deletion python/sglang/srt/managers/io_struct.py
@@ -25,7 +25,6 @@ class GenerateReqInput:
return_text_in_logprobs: bool = False
# Whether to stream output
stream: bool = False
# TODO: make all parameters a Union[List[T], T] to allow for batched requests

def post_init(self):
is_single = isinstance(self.text, str)
34 changes: 19 additions & 15 deletions python/sglang/srt/managers/router/infer_batch.py
@@ -1,5 +1,5 @@
from dataclasses import dataclass
from enum import Enum, auto
from enum import IntEnum, auto
from typing import List

import numpy as np
@@ -9,15 +9,15 @@
from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool


class ForwardMode(Enum):
class ForwardMode(IntEnum):
PREFILL = auto()
EXTEND = auto()
DECODE = auto()


class FinishReason(Enum):
LENGTH = auto()
class FinishReason(IntEnum):
EOS_TOKEN = auto()
LENGTH = auto()
STOP_STR = auto()


@@ -31,6 +31,7 @@ def __init__(self, rid, input_text, input_ids):
# Since jump forward may retokenize the prompt with partial outputs,
# we maintain the original prompt length to report the correct usage.
self.prompt_tokens = len(input_ids)

# The number of decoded tokens for token usage report. Note that
# this does not include the jump forward tokens.
self.completion_tokens_wo_jump_forward = 0
@@ -41,12 +42,11 @@ def __init__(self, rid, input_text, input_ids):
self.image_offset = 0
self.pad_value = None

# Sampling parameters
self.sampling_params = None
self.return_logprob = False
self.logprob_start_len = 0
self.top_logprobs_num = 0
self.stream = False

# Check finish
self.tokenizer = None
self.finished = False
self.finish_reason = None
@@ -56,13 +56,17 @@ def __init__(self, rid, input_text, input_ids):
self.prefix_indices = []
self.last_node = None

# Logprobs
self.return_logprob = False
self.logprob_start_len = 0
self.top_logprobs_num = 0
self.normalized_prompt_logprob = None
self.prefill_token_logprobs = None
self.decode_token_logprobs = None
self.normalized_prompt_logprob = None
self.prefill_top_logprobs = None
self.decode_top_logprobs = None

# For constrained decoding
# Constrained decoding
self.regex_fsm = None
self.regex_fsm_state = 0
self.jump_forward_map = None
@@ -165,8 +169,8 @@ class Batch:
out_cache_cont_end: torch.Tensor = None

# for processing logprobs
top_logprobs_nums: List[int] = None
return_logprob: bool = False
top_logprobs_nums: List[int] = None

# for multimodal
pixel_values: List[torch.Tensor] = None
@@ -321,8 +325,8 @@ def retract_decode(self):
)

retracted_reqs = []
seq_lens_np = self.seq_lens.cpu().numpy()
req_pool_indices_np = self.req_pool_indices.cpu().numpy()
seq_lens_cpu = self.seq_lens.cpu().numpy()
req_pool_indices_cpu = self.req_pool_indices.cpu().numpy()
while self.token_to_kv_pool.available_size() < len(self.reqs):
idx = sorted_indices.pop()
req = self.reqs[idx]
@@ -338,8 +342,8 @@ def retract_decode(self):
# TODO: apply more fine-grained retraction

token_indices = self.req_to_token_pool.req_to_token[
req_pool_indices_np[idx]
][: seq_lens_np[idx]]
req_pool_indices_cpu[idx]
][: seq_lens_cpu[idx]]
self.token_to_kv_pool.dec_refs(token_indices)

self.filter_batch(sorted_indices)
@@ -363,7 +367,7 @@ def check_for_jump_forward(self):
# insert the old request into tree_cache
token_ids_in_memory = tuple(req.input_ids + req.output_ids)[:-1]
if req_pool_indices_cpu is None:
req_pool_indices_cpu = self.req_pool_indices.cpu().tolist()
req_pool_indices_cpu = self.req_pool_indices.tolist()
req_pool_idx = req_pool_indices_cpu[i]
indices = self.req_to_token_pool.req_to_token[
req_pool_idx, : len(token_ids_in_memory)
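For context on the `Enum` to `IntEnum` switch at the top of this file, a small sketch of the behavioral difference; the motivation (int-like comparison and serialization) is my assumption, the commit itself does not state one.

```python
from enum import Enum, IntEnum, auto

class PlainMode(Enum):
    DECODE = auto()

class IntMode(IntEnum):
    DECODE = auto()

print(PlainMode.DECODE == 1)  # False: plain Enum members never equal ints
print(IntMode.DECODE == 1)    # True: IntEnum members behave as their integer value
print(int(IntMode.DECODE))    # 1, so the value can be stored or compared as a plain int
```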