Skip to content

Commit

Permalink
support spec_tokens=512
Browse files Browse the repository at this point in the history
  • Loading branch information
parasol-aser committed Jan 19, 2024
1 parent fb44aa5 commit d270252
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 7 deletions.
3 changes: 2 additions & 1 deletion examples/quick_start/openai_example_speculative.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from sglang import function, gen, set_default_backend, OpenAI
@function

@function(spec_tokens=512)
def example(s):
s += "Construct a character. Here is an example:\n"
s += "Name: Steve Jobs. Birthday: February 24, 1955. Job: Apple CEO.\n"
Expand Down
10 changes: 6 additions & 4 deletions python/sglang/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,12 @@
SglSelect,
)


def function(func: Callable):
    """Wrap ``func`` in an ``SglFunction`` and return the wrapper."""
    wrapped = SglFunction(func)
    return wrapped

def function(func=None, **bind_arguments):
    """Wrap a function in an ``SglFunction``, optionally binding arguments.

    Usable as a bare decorator (``@function``) or with keyword arguments
    (``@function(spec_tokens=512)``).

    Args:
        func: The function to wrap. ``None`` when the decorator is invoked
            with keyword arguments only.
        **bind_arguments: Keyword arguments handed to ``SglFunction`` as a
            dict, alongside the wrapped function.

    Returns:
        An ``SglFunction`` when ``func`` is given; otherwise a decorator
        that builds one from the function it receives.
    """

    def decorator(target):
        return SglFunction(target, bind_arguments)

    # Compare against None instead of relying on truthiness, so a callable
    # object that happens to be falsy is still wrapped directly.
    if func is None:
        return decorator
    return SglFunction(func, bind_arguments)

def Runtime(*args, **kwargs):
# Avoid importing unnecessary dependency
Expand Down
11 changes: 9 additions & 2 deletions python/sglang/lang/interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,11 @@ def run_program(
program, backend, func_args, func_kwargs, default_sampling_para, stream, sync=False
):
assert backend is not None, "Please specify a backend"
func_kwargs.update(program.bind_arguments)
# func_kwargs.update(program.bind_arguments)
stream_executor = StreamExecutor(
backend, func_kwargs, default_sampling_para, chat_template=None, stream=stream
)
stream_executor.bind_arguments = program.bind_arguments
state = ProgramState(stream_executor)

if stream:
Expand Down Expand Up @@ -170,7 +171,8 @@ def __init__(

# For speculative execution
self.last_comp = "" # The last completion message returned
self.speculate_text_prefix = "" # The speculated text prefix
self.speculate_text_prefix = "" # The speculated text prefix
self.bind_arguments = {}

# For chat
self.messages_ = [] # The messages in the OpenAI API format
Expand Down Expand Up @@ -361,6 +363,11 @@ def _execute_gen(self, expr: SglGen):
self.variables[self.last_name] = last_comp
meta_info = {}
else:
if self.last_comp == "" and 'spec_tokens' in self.bind_arguments:
spec_tokens_value = self.bind_arguments['spec_tokens']
sampling_params.max_new_tokens = spec_tokens_value
sampling_params.stop = None

comp, meta_info = self.backend.generate(
self, sampling_params=sampling_params
)
Expand Down

0 comments on commit d270252

Please sign in to comment.