Skip to content

Commit

Permalink
Add programmatic streaming support, better openai api alignment, no lo…
Browse files Browse the repository at this point in the history
…nger auto choose arbitrary stop text
  • Loading branch information
slundberg committed May 26, 2023
1 parent 8e1a42b commit 2f9ef9d
Show file tree
Hide file tree
Showing 8 changed files with 251 additions and 71 deletions.
86 changes: 79 additions & 7 deletions guidance/_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ def __init__(self, text, llm=None, cache_seed=0, logprobs=None, silent='auto', a
self._executor = None # the ProgramExecutor object that is running the program
self._last_display_update = 0 # the last time we updated the display (used for throttling updates)
self._execute_complete = asyncio.Event() # fires when the program is done executing to resolve __await__
self._emit_stream_event = asyncio.Event() # fires when we need to emit a stream event
self._displaying = not self.silent # if we are displaying we need to update the display as we execute
self._displayed = False # marks if we have been displayed in the client yet
self._displaying_html = False # if we are displaying html (vs. text)
Expand Down Expand Up @@ -162,16 +163,19 @@ def _ipython_display_(self):


async def _await_finish_execute(self):
""" Used by self.__await__ to wait for the program to complete.
"""
"""Used by self.__await__ to wait for the program to complete."""
await self._execute_complete.wait() # wait for the program to finish executing
return self

def __await__(self):
return self._await_finish_execute().__await__()

def __aiter__(self):
"""Return an async iterator that yields the program in partial states as it is run."""
return self._stream_run_async()

def __call__(self, **kwargs):
""" Execute this program with the given variable values and return a new executed/executing program.
"""Execute this program with the given variable values and return a new executed/executing program.
Note that the returned program might not be fully executed if `stream=True`. When streaming you need to
use the python `await` keyword if you want to ensure the program is finished (note that is different than
Expand Down Expand Up @@ -214,7 +218,7 @@ def __call__(self, **kwargs):
# if we are not in async mode, we need to create a new event loop and run the program in it until it is done
else:

# apply the nested event loop patch if needed
# apply nested event loop patch if needed
try:
other_loop = asyncio.get_event_loop()
nest_asyncio.apply(other_loop)
Expand All @@ -223,12 +227,68 @@ def __call__(self, **kwargs):

loop = asyncio.new_event_loop()
loop.create_task(new_program.update_display.run()) # start the display updater
loop.run_until_complete(new_program.execute())

if new_program.stream:
return self._stream_run(loop, new_program)
else:
loop.run_until_complete(new_program.execute())

return new_program

def get(self, key, default=None):
    """Look up the program variable `key`, returning `default` when absent."""
    try:
        return self._variables[key]
    except KeyError:
        return default

def _stream_run(self, loop, new_program):
    """Synchronously drive `new_program` on `loop`, yielding partial states.

    Generator returned by `__call__` when `stream=True` outside async mode.
    `_update_display` stops `loop` each time a new partial state is ready,
    which lets this generator yield between loop runs.
    This feels a bit hacky at the moment. TODO: clean this up.
    """

    # add the program execution to the event loop
    execute_task = loop.create_task(new_program.execute())

    # run the event loop until the program is done executing
    while new_program._executor is not None:
        try:
            loop.run_until_complete(execute_task) # this will stop each time the program wants to emit a new state
        except RuntimeError as e:
            # we don't mind that the task is not yet done, we will restart the loop
            if str(e) != "Event loop stopped before Future completed.":
                raise e
        if getattr(loop, "_stopping", False):
            loop._stopping = False # clean up the stopping flag
        # only yield when the executor is still live and actively executing
        if new_program._executor is not None and new_program._executor.executing:
            try:
                yield new_program
            except GeneratorExit:
                # this will cause the program to stop executing and finish as a valid partial execution
                if new_program._executor.executing:
                    new_program._executor.executing = False
    # final yield: the finished (possibly partially-executed) program
    yield new_program

    # cancel all tasks and close the loop
    for task in asyncio.all_tasks(loop=loop):
        task.cancel()
    loop.run_until_complete(asyncio.sleep(0)) # give the loop a chance to cancel the tasks
    loop.close() # we are done with the loop (note that the loop is already stopped)

async def _stream_run_async(self):
    """Async generator of this program's partial states (backs `__aiter__`)."""

    # run the event loop until the program is done executing
    while self._executor is not None:
        if self._executor.executing:
            # _emit_stream_event is set by _update_display when there is a new state
            await self._emit_stream_event.wait()
            self._emit_stream_event.clear()
        try:
            yield self
        except GeneratorExit as e:
            # this will cause the program to stop executing and finish as a valid partial execution
            if self._executor.executing:
                self._executor.executing = False
                # let the executor wind down into a valid partial state
                await self._execute_complete.wait()
            # re-raise so the async generator machinery sees the close
            raise e
    # final state once the executor has been torn down
    yield self

def _update_display(self, last=False):
""" Updates the display with the current marked text after debouncing.
"""Updates the display with the current marked text after debouncing.
Parameters
----------
Expand All @@ -241,6 +301,18 @@ def _update_display(self, last=False):

log.debug(f"Updating display (last={last}, self._displaying={self._displaying}, self._comm={self._comm})")


if self.stream:
if self.async_mode:
# if we are streaming in async mode then we set the event to let the generator know it can yield
self._emit_stream_event.set()

else:
# if we are streaming not in async mode then we pause the event loop to let the generator
# that is controlling execution return (it will restart the event loop when it is ready)
if self._executor is not None:
asyncio.get_event_loop().stop()

# this is always called during execution, and we only want to update the display if we are displaying
if not self._displaying:
return
Expand Down
19 changes: 12 additions & 7 deletions guidance/library/_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

log = logging.getLogger(__name__)

async def gen(name=None, stop=None, stop_regex=None, save_stop_text=False, max_tokens=500, n=1,
async def gen(name=None, stop=None, stop_regex=None, save_stop_text=False, max_tokens=500, n=1, stream=None,
temperature=0.0, top_p=1.0, logprobs=None, pattern=None, hidden=False, list_append=False,
save_prompt=False, token_healing=None, _parser_context=None):
''' Use the LLM to generate a completion.
Expand Down Expand Up @@ -102,9 +102,9 @@ async def gen(name=None, stop=None, stop_regex=None, save_stop_text=False, max_t
if next_text.startswith(end_tag):
stop = end_tag

# fall back to the next node's text
if stop is None:
stop = next_text
# fall back to the next node's text (this was too easy to accidentally trigger, so we disable it now)
# if stop is None:
# stop = next_text

if stop == "":
stop = None
Expand All @@ -116,8 +116,13 @@ async def gen(name=None, stop=None, stop_regex=None, save_stop_text=False, max_t
else:
cache_seed = 0

# we can't stream batches right now
stream_generation = parser.program.stream if n == 1 else False
# set streaming default
if stream is None:
stream = parser.program.stream or parser.program._displaying or stop_regex is not None if n == 1 else False

# we can't stream batches right now TODO: fix this
assert not (stream and n > 1), "You can't stream batches of completions right now."
# stream_generation = parser.program.stream if n == 1 else False

# save the prompt if requested
if save_prompt:
Expand All @@ -132,7 +137,7 @@ async def gen(name=None, stop=None, stop_regex=None, save_stop_text=False, max_t
gen_obj = await parser.llm_session(
parser_prefix+prefix, stop=stop, stop_regex=stop_regex, max_tokens=max_tokens, n=n, pattern=pattern,
temperature=temperature, top_p=top_p, logprobs=logprobs, cache_seed=cache_seed, token_healing=token_healing,
echo=parser.program.logprobs is not None, stream=stream_generation, caching=parser.program.caching
echo=parser.program.logprobs is not None, stream=stream, caching=parser.program.caching
)

if n == 1:
Expand Down
16 changes: 11 additions & 5 deletions guidance/library/_geneach.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,12 +108,15 @@ async def geneach(list_name, stop=None, max_iterations=100, min_iterations=0, nu
pos = len(parser.prefix)

# add the join string if we are not on the first iteration
if len(data) > 0 and join != "":
if i > 0 and join != "":
partial_output(join)

await parser.visit(block_content[0]) # fills out parser.prefix
block_variables = parser.variable_stack.pop()["this"]
data.append(block_variables)

# update the list variable (we do this each time we get a new item so that streaming works)
parser.set_variable(list_name, parser.get_variable(list_name, default_value=[]) + [block_variables])

if hidden:
# new_content = parser.prefix[pos:]
parser.reset_prefix(pos)
Expand Down Expand Up @@ -190,14 +193,17 @@ async def geneach(list_name, stop=None, max_iterations=100, min_iterations=0, nu
# get the variables that were generated
match_dict = m.groupdict()
if "this" in match_dict:
data.append(match_dict["this"])
next_item = match_dict["this"]
else:
d = {}
for k in match_dict:
k = _unescape_group_name(k)
if k.startswith("this."):
d[k[5:]] = match_dict[k].strip()
data.append(d)
next_item = d

# update the list variable (we do this each time we get a new item so that streaming works)
parser.set_variable(list_name, parser.get_variable(list_name, default_value=[]) + [next_item])

# recreate the output string with format markers added
item_out = re.sub(
Expand All @@ -221,7 +227,7 @@ async def geneach(list_name, stop=None, max_iterations=100, min_iterations=0, nu
partial_output("{{!--GMARKER_each$$--}}") # end marker

# parser.get_variable(list, [])
parser.set_variable(list_name, parser.get_variable(list_name, default_value=[]) + data)
#parser.set_variable(list_name, parser.get_variable(list_name, default_value=[]) + data)

# if we have stopped executing, we need to add the loop to the output so it can be executed later
if not parser.executing:
Expand Down
Loading

0 comments on commit 2f9ef9d

Please sign in to comment.