diff --git a/autogen/oai/chat_completion_proxy.py b/autogen/oai/chat_completion_proxy.py
index 5a1d105715f..a8f6c443400 100644
--- a/autogen/oai/chat_completion_proxy.py
+++ b/autogen/oai/chat_completion_proxy.py
@@ -3,29 +3,30 @@
 from openai.openai_object import OpenAIObject
 
-class ChatCompletionProxy():
+
+class ChatCompletionProxy:
     @classmethod
     def _prompt_tokens(cls, messages):
         # Get the encoding for OpenAI's "cl100k_base" model
         encoding = tiktoken.get_encoding("cl100k_base")
-
+
         # Calculate the total number of tokens in the prompt
         # by iterating over each message in the 'messages' list,
         # encoding its content, and summing up the token counts.
-        return sum([len(encoding.encode(msg['content'])) for msg in messages])
+        return sum([len(encoding.encode(msg["content"])) for msg in messages])
 
     @classmethod
     def create(cls, *args, **kwargs):
         # Check if streaming is enabled in the function arguments
-        if kwargs.get('stream', False) and 'functions' not in kwargs:
+        if kwargs.get("stream", False) and "functions" not in kwargs:
             # Prepare response array based on parameter 'n'
-            response_contents = [""] * kwargs.get('n', 1)
-            finish_reasons = [""] * kwargs.get('n', 1)
+            response_contents = [""] * kwargs.get("n", 1)
+            finish_reasons = [""] * kwargs.get("n", 1)
             completion_tokens = 0
-
+
             # Set the terminal text color to green for better visibility
-            print("\033[32m", end='')
-
+            print("\033[32m", end="")
+
             # Send the chat completion request to OpenAI's API and process the response in chunks
             for chunk in openai.ChatCompletion.create(*args, **kwargs):
                 if chunk["choices"]:
@@ -33,43 +34,42 @@ def create(cls, *args, **kwargs):
                         content = choice.get("delta", {}).get("content")
                         # If content is present, print it to the terminal and update response variables
                         if content is not None:
-                            print(content, end='', flush=True)
+                            print(content, end="", flush=True)
                             response_contents[choice.index] += content
                             finish_reasons[choice.index] = choice.get("finish_reasons", None)
                             completion_tokens += 1
                         else:
                             print()
-
+
             # Reset the terminal text color
             print("\033[0m\n")
-
+
             # Prepare the final response object based on the accumulated data
             prompt_tokens = cls._prompt_tokens(kwargs["messages"])
             response = OpenAIObject()
-            response.id = chunk['id']
-            response.object = 'chat.completion'
-            response.created = chunk['created']
-            response.model = chunk['model']
+            response.id = chunk["id"]
+            response.object = "chat.completion"
+            response.created = chunk["created"]
+            response.model = chunk["model"]
             response.choices = []
             response.usage = {
                 "prompt_tokens": prompt_tokens,
                 "completion_tokens": completion_tokens,
-                "total_tokens": prompt_tokens + completion_tokens
+                "total_tokens": prompt_tokens + completion_tokens,
             }
             for i in range(len(response_contents)):
-                response["choices"].append({
-                    "index": i,
-                    "finish_reason": finish_reasons[i],
-                    "message": {
-                        'role': 'assistant',
-                        'content': response_contents[i]
+                response["choices"].append(
+                    {
+                        "index": i,
+                        "finish_reason": finish_reasons[i],
+                        "message": {"role": "assistant", "content": response_contents[i]},
                     }
-                })
+                )
         else:
             # If streaming is not enabled, send a regular chat completion request
             # Ensure streaming is disabled
-            kwargs['stream'] = False
+            kwargs["stream"] = False
             response = openai.ChatCompletion.create(*args, **kwargs)
-
+
         # Return the final response object
         return response
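Side note on the `_prompt_tokens` helper above: streamed chunks from the legacy `openai.ChatCompletion` endpoint carry no `usage` block, so the proxy rebuilds it, counting prompt tokens locally with `tiktoken`. A minimal standalone sketch of that counting step; the message list here is hypothetical, and per-message role/format overhead is deliberately ignored, just as in the helper:

```python
import tiktoken

# Hypothetical chat history; in the proxy this is kwargs["messages"].
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Summarize the plot of Hamlet in one sentence."},
]

# Same encoding the helper requests; cl100k_base is the tokenizer used by the
# gpt-3.5-turbo / gpt-4 model families.
encoding = tiktoken.get_encoding("cl100k_base")

# Mirror of _prompt_tokens: encode each message body and sum the token counts.
# This counts message content only, not the per-message formatting overhead.
prompt_tokens = sum(len(encoding.encode(msg["content"])) for msg in messages)
print(prompt_tokens)
```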
diff --git a/autogen/oai/completion.py b/autogen/oai/completion.py
index 38d8cde4dee..18239842a32 100644
--- a/autogen/oai/completion.py
+++ b/autogen/oai/completion.py
@@ -208,7 +208,7 @@ def _get_response(cls, config: Dict, raise_on_ratelimit_or_timeout=False, use_ca
                 cls._book_keeping(config, response)
                 return response
         openai_completion = (
-            ChatCompletionProxy # Support streaming for chat models
+            ChatCompletionProxy  # Support streaming for chat models
             if config["model"].replace("gpt-35-turbo", "gpt-3.5-turbo") in cls.chat_models or issubclass(cls, ChatCompletion)
             else openai.Completion
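For reference, a rough usage sketch of the proxy on its own, assuming the pre-1.0 `openai` SDK that `openai.ChatCompletion` belongs to and the module path shown in the diff; the model name, key handling, and prompt are illustrative only:

```python
import openai
from autogen.oai.chat_completion_proxy import ChatCompletionProxy

openai.api_key = "sk-..."  # placeholder; the legacy SDK reads the key from this module attribute

# With stream=True (and no "functions" argument) the proxy prints deltas as they
# arrive, then returns a regular chat.completion object whose usage block is
# rebuilt from the tiktoken prompt count plus one completion token per streamed delta.
response = ChatCompletionProxy.create(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Write a haiku about token streaming."}],
    stream=True,
)
print(response["usage"])
```

In `completion.py`, the same call is reached indirectly: `_get_response` selects `ChatCompletionProxy` over `openai.Completion` whenever the configured model is a chat model, so streaming kwargs pass through unchanged.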