Skip to content

Commit

Permalink
Enable streaming support for openai v1 (#597)
Browse files Browse the repository at this point in the history
* Enable streaming support for openai v1

* Added tests for openai client streaming

* Fix test_completion_stream
  • Loading branch information
Alvaromah authored Nov 11, 2023
1 parent 8d225db commit 849feda
Show file tree
Hide file tree
Showing 2 changed files with 152 additions and 2 deletions.
68 changes: 66 additions & 2 deletions autogen/oai/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,14 @@
from flaml.automl.logger import logger_formatter

from autogen.oai.openai_utils import get_key
from autogen.token_count_utils import count_token

try:
from openai import OpenAI, APIError
from openai.types.chat import ChatCompletion
from openai.types.chat.chat_completion import ChatCompletionMessage, Choice
from openai.types.completion import Completion
from openai.types.completion_usage import CompletionUsage
import diskcache

ERROR = None
Expand Down Expand Up @@ -237,9 +240,8 @@ def yes_or_no_filter(context, response):
response.pass_filter = pass_filter
# TODO: add response.cost
return response
completions = client.chat.completions if "messages" in params else client.completions
try:
response = completions.create(**params)
response = self._completions_create(client, params)
except APIError:
logger.debug(f"config {i} failed", exc_info=1)
if i == last:
Expand All @@ -250,6 +252,68 @@ def yes_or_no_filter(context, response):
cache.set(key, response)
return response

def _completions_create(self, client, params):
completions = client.chat.completions if "messages" in params else client.completions
# If streaming is enabled, has messages, and does not have functions, then
# iterate over the chunks of the response
if params.get("stream", False) and "messages" in params and "functions" not in params:
response_contents = [""] * params.get("n", 1)
finish_reasons = [""] * params.get("n", 1)
completion_tokens = 0

# Set the terminal text color to green
print("\033[32m", end="")

# Send the chat completion request to OpenAI's API and process the response in chunks
for chunk in completions.create(**params):
if chunk.choices:
for choice in chunk.choices:
content = choice.delta.content
finish_reasons[choice.index] = choice.finish_reason
# If content is present, print it to the terminal and update response variables
if content is not None:
print(content, end="", flush=True)
response_contents[choice.index] += content
completion_tokens += 1
else:
print()

# Reset the terminal text color
print("\033[0m\n")

# Prepare the final ChatCompletion object based on the accumulated data
model = chunk.model.replace("gpt-35", "gpt-3.5") # hack for Azure API
prompt_tokens = count_token(params["messages"], model)
response = ChatCompletion(
id=chunk.id,
model=chunk.model,
created=chunk.created,
object="chat.completion",
choices=[],
usage=CompletionUsage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
),
)
for i in range(len(response_contents)):
response.choices.append(
Choice(
index=i,
finish_reason=finish_reasons[i],
message=ChatCompletionMessage(
role="assistant", content=response_contents[i], function_call=None
),
)
)
else:
# If streaming is not enabled or using functions, send a regular chat completion request
# Functions are not supported, so ensure streaming is disabled
params = params.copy()
params["stream"] = False
response = completions.create(**params)
return response

@classmethod
def extract_text_or_function_call(cls, response: ChatCompletion | Completion) -> List[str]:
"""Extract the text or function calls from a completion or chat response.
Expand Down
86 changes: 86 additions & 0 deletions test/oai/test_client_stream.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
import pytest
from autogen import OpenAIWrapper, config_list_from_json, config_list_openai_aoai
from test_utils import OAI_CONFIG_LIST, KEY_LOC

try:
from openai import OpenAI
except ImportError:
skip = True
else:
skip = False


@pytest.mark.skipif(skip, reason="openai>=1 not installed")
def test_aoai_chat_completion_stream():
    """Smoke-test streaming chat completion against an Azure OpenAI config."""
    azure_filter = {"api_type": ["azure"], "model": ["gpt-3.5-turbo"]}
    configs = config_list_from_json(
        env_or_file=OAI_CONFIG_LIST, file_location=KEY_LOC, filter_dict=azure_filter
    )
    wrapper = OpenAIWrapper(config_list=configs)
    reply = wrapper.create(messages=[{"role": "user", "content": "2+2="}], seed=None, stream=True)
    print(reply)
    print(wrapper.extract_text_or_function_call(reply))


@pytest.mark.skipif(skip, reason="openai>=1 not installed")
def test_chat_completion_stream():
    """Smoke-test streaming chat completion against a plain OpenAI config."""
    configs = config_list_from_json(
        env_or_file=OAI_CONFIG_LIST,
        file_location=KEY_LOC,
        filter_dict={"model": ["gpt-3.5-turbo"]},
    )
    wrapper = OpenAIWrapper(config_list=configs)
    reply = wrapper.create(messages=[{"role": "user", "content": "1+1="}], seed=None, stream=True)
    print(reply)
    print(wrapper.extract_text_or_function_call(reply))


@pytest.mark.skipif(skip, reason="openai>=1 not installed")
def test_chat_functions_stream():
    """Streaming request with functions attached: streaming should be
    silently disabled by the client and a normal response returned."""
    configs = config_list_from_json(
        env_or_file=OAI_CONFIG_LIST,
        file_location=KEY_LOC,
        filter_dict={"model": ["gpt-3.5-turbo"]},
    )
    weather_fn = {
        "name": "get_current_weather",
        "description": "Get the current weather",
        "parameters": {
            "type": "object",
            "properties": {
                "location": {
                    "type": "string",
                    "description": "The city and state, e.g. San Francisco, CA",
                },
            },
            "required": ["location"],
        },
    }
    wrapper = OpenAIWrapper(config_list=configs)
    reply = wrapper.create(
        messages=[{"role": "user", "content": "What's the weather like today in San Francisco?"}],
        functions=[weather_fn],
        seed=None,
        stream=True,
    )
    print(reply)
    print(wrapper.extract_text_or_function_call(reply))


@pytest.mark.skipif(skip, reason="openai>=1 not installed")
def test_completion_stream():
    """Streaming request to the legacy completions endpoint: streaming is
    unsupported there, so the client should fall back to a regular call."""
    wrapper = OpenAIWrapper(config_list=config_list_openai_aoai(KEY_LOC))
    reply = wrapper.create(prompt="1+1=", model="gpt-3.5-turbo-instruct", seed=None, stream=True)
    print(reply)
    print(wrapper.extract_text_or_function_call(reply))


if __name__ == "__main__":
    # Run the streaming smoke tests directly when invoked as a script,
    # in the same order pytest would collect them.
    for _case in (
        test_aoai_chat_completion_stream,
        test_chat_completion_stream,
        test_chat_functions_stream,
        test_completion_stream,
    ):
        _case()

0 comments on commit 849feda

Please sign in to comment.