
Commit

refactor code
Wanglongzhi2001 committed Dec 17, 2024
1 parent 6ef6917 commit 62d4ec2
Showing 3 changed files with 16 additions and 20 deletions.
3 changes: 0 additions & 3 deletions llm/server/server/engine/config.py
@@ -91,9 +91,6 @@ def read_from_env(self):
self.block_size = int(env.get("BLOCK_SIZE", 64))
self.use_cache_kv_int8 = int(os.getenv("USE_CACHE_KV_INT8", 0))
self.use_cache_kv_int4 = int(os.getenv("USE_CACHE_KV_INT4", 0))

# speculate decoding config
self.speculate_method = str(os.getenv("SPECULATE_METHOD", None))

# infer config
self.max_batch_size = int(env.get("BATCH_SIZE", 50))
20 changes: 10 additions & 10 deletions llm/server/server/engine/infer.py
@@ -47,6 +47,7 @@ def __init__(self, args):

self.config = Config()
self.model_cfg = self.config.get_model_config()
self.is_speculate_decoding = self.model_cfg.get("speculate_method") is not None
self.format_print_configuration()

self.args.num_layers = self.get_value(self.model_cfg, ["num_hidden_layers", "num_layers"])
@@ -63,16 +64,16 @@ def __init__(self, args):
self.cache_kvs = {}
self.init_inputs()

# whether use speculate decoding
if self.config.speculate_method is not None:
if self.config.speculate_method == "inference_with_reference":
if self.is_speculate_decoding:
logger.info(f'Using speculating decoding, method: {self.model_cfg["speculate_method"]}.')
if self.model_cfg["speculate_method"] == "inference_with_reference":
self.proposer = InferenceWithReferenceProposer(
self.model_cfg["speculate_max_draft_token_num"],
self.model_cfg["speculate_max_ngram_size"],
self.args.max_batch_size,
self.args.max_seq_len)
else:
raise NotImplementedError(f'Not support {self.config.speculate_method}, only support inference_with_reference now.')
raise NotImplementedError(f'Not support {self.model_cfg["speculate_method"]}, only support inference_with_reference now.')
else:
self.proposer = None

@@ -261,7 +262,7 @@ def init_inputs(self):
shape=[1], fill_value=self.free_list_len, dtype="int32")

# speculate decoding input
if self.config.speculate_method is not None:
if self.is_speculate_decoding:
self.share_inputs["accept_tokens"] = paddle.full(
shape=[self.args.max_batch_size, self.model_cfg["speculate_max_draft_token_num"] + 1], fill_value=0, dtype="int64"
)
@@ -315,16 +316,15 @@ def dy_input_preprocess(self, tasks):
self.share_inputs["block_tables"][idx:idx + 1, :encoder_block_num] = np.array(
task['block_tables'], dtype="int32")

if self.proposer is not None:
if self.config.speculate_method == "inference_with_reference":
self.share_inputs["draft_tokens"][idx:idx + 1] = np.zeros([self.model_cfg["speculate_max_draft_token_num"] + 1])
self.share_inputs["actual_draft_token_num"][idx:idx + 1] = np.array([self.model_cfg["speculate_max_draft_token_num"]])
if self.is_speculate_decoding:
self.share_inputs["draft_tokens"][idx:idx + 1] = np.zeros([self.model_cfg["speculate_max_draft_token_num"] + 1])
self.share_inputs["actual_draft_token_num"][idx:idx + 1] = np.array([self.model_cfg["speculate_max_draft_token_num"]])

def step_cuda(self, seq_lens_this_time):
"""
step cuda
"""
if self.config.speculate_method is None:
if not self.is_speculate_decoding:
step_paddle(self.share_inputs['stop_flags'], seq_lens_this_time,
self.share_inputs['step_seq_lens_encoder'],
self.share_inputs['seq_lens_encoder'],
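For reference, a minimal standalone sketch of the pattern the infer.py changes settle on: read speculate_method from the model config once, cache the result as a boolean, and gate proposer setup on that flag. All values below are hypothetical stand-ins; in the server the config comes from Config().get_model_config() and the real proposer is InferenceWithReferenceProposer.

# Hypothetical model config; the server loads this via Config().get_model_config().
model_cfg = {
    "speculate_method": "inference_with_reference",
    "speculate_max_draft_token_num": 5,
    "speculate_max_ngram_size": 3,
}
max_batch_size, max_seq_len = 8, 4096

# Check the method once and reuse the boolean everywhere else.
is_speculate_decoding = model_cfg.get("speculate_method") is not None

proposer = None
if is_speculate_decoding:
    if model_cfg["speculate_method"] == "inference_with_reference":
        # Stand-in for InferenceWithReferenceProposer(...); only the argument order is shown.
        proposer = ("inference_with_reference",
                    model_cfg["speculate_max_draft_token_num"],
                    model_cfg["speculate_max_ngram_size"],
                    max_batch_size, max_seq_len)
    else:
        raise NotImplementedError(
            f'{model_cfg["speculate_method"]} is not supported; only inference_with_reference is.')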
13 changes: 6 additions & 7 deletions llm/server/server/engine/token_processor.py
@@ -22,9 +22,8 @@
import numpy as np
from paddlenlp_ops import get_output, speculate_get_output
from server.utils import datetime_diff, model_server_logger, monitor_logger
from paddlenlp.utils.env import MAX_DRAFT_TOKENS, SPECULATE_MAX_BSZ

SPECULATE_MAX_BSZ = 256
MAX_DRAFT_TOKEN_NUM = 6

class TokenProcessor(object):
"""
@@ -40,8 +39,9 @@ def __init__(self, cfg):

self.tokens_counter = Counter()

if self.cfg.speculate_method is not None:
self.output_tokens = paddle.full(shape=[SPECULATE_MAX_BSZ * MAX_DRAFT_TOKEN_NUM + SPECULATE_MAX_BSZ + 2], fill_value=2, dtype="int64")
self.is_speculate_decoding = self.cfg.get_model_config().get("speculate_method") is not None
if self.is_speculate_decoding:
self.output_tokens = paddle.full(shape=[SPECULATE_MAX_BSZ * MAX_DRAFT_TOKENS + SPECULATE_MAX_BSZ + 2], fill_value=2, dtype="int64")
else:
self.output_tokens = paddle.full(shape=[self.cfg.max_batch_size + 2, 1], fill_value=2, dtype="int64")
self.worker = None
@@ -71,7 +71,7 @@ def run(self):
if self.worker is not None:
raise Exception("Worker is already running!")

if self.cfg.speculate_method is not None:
if self.is_speculate_decoding:
self.worker = threading.Thread(target=self.process_speculate_results, args=())
else:
self.worker = threading.Thread(target=self.process_sampling_results, args=())
@@ -302,7 +302,6 @@ def _process_speculate_output(self):
batch post-processing function
"""
tokens = self.output_tokens.numpy()
model_server_logger.info(f"speculate_result tokens: {self.output_tokens.tolist()}")
batch = self.output_tokens[1]
output_token_msg_id = int(self.output_tokens[0])
accept_num = tokens[2 : batch + 2]
@@ -317,7 +316,7 @@
if self.resource_manager.stop_flags[i]:
continue

token_ids = tokens[2 + SPECULATE_MAX_BSZ + i * MAX_DRAFT_TOKEN_NUM: 2 + SPECULATE_MAX_BSZ + i * MAX_DRAFT_TOKEN_NUM + accept_num[i]].tolist()
token_ids = tokens[2 + SPECULATE_MAX_BSZ + i * MAX_DRAFT_TOKENS: 2 + SPECULATE_MAX_BSZ + i * MAX_DRAFT_TOKENS + accept_num[i]].tolist()
# skip invalid tokens
if len(token_ids) == 0 or token_ids[-1] == 0:
continue
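The slicing in _process_speculate_output assumes a flat layout for output_tokens: index 0 carries the message id, index 1 the batch size, the next SPECULATE_MAX_BSZ slots the per-request accept counts, and the remainder one MAX_DRAFT_TOKENS-sized block of accepted tokens per request. A self-contained sketch with small made-up sizes (the real constants are imported from paddlenlp.utils.env):

import numpy as np

# Made-up sizes for illustration only.
SPECULATE_MAX_BSZ = 4
MAX_DRAFT_TOKENS = 3

batch = 2
tokens = np.zeros(SPECULATE_MAX_BSZ * MAX_DRAFT_TOKENS + SPECULATE_MAX_BSZ + 2, dtype="int64")
tokens[0] = 123                # output_token_msg_id
tokens[1] = batch              # number of requests in this batch
tokens[2:2 + batch] = [2, 1]   # accept_num per request
tokens[2 + SPECULATE_MAX_BSZ:2 + SPECULATE_MAX_BSZ + MAX_DRAFT_TOKENS] = [11, 12, 0]  # request 0
tokens[2 + SPECULATE_MAX_BSZ + MAX_DRAFT_TOKENS:2 + SPECULATE_MAX_BSZ + 2 * MAX_DRAFT_TOKENS] = [21, 0, 0]  # request 1

accept_num = tokens[2:2 + batch]
for i in range(batch):
    start = 2 + SPECULATE_MAX_BSZ + i * MAX_DRAFT_TOKENS
    token_ids = tokens[start:start + accept_num[i]].tolist()
    print(i, token_ids)  # prints: 0 [11, 12]  then  1 [21]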
