From 56aed2d3d19d03c9d48d2ae22f94a4ae2cb16ba4 Mon Sep 17 00:00:00 2001 From: Davor Runje Date: Thu, 11 Jan 2024 05:34:51 +0100 Subject: [PATCH] Added support for streaming tool calls (#1184) * added support for streaming tool calls * bug fix: removed tmp assert --------- Co-authored-by: Chi Wang --- autogen/oai/client.py | 301 +++++++++++++++++++++++++-------- autogen/oai/openai_utils.py | 4 +- test/oai/test_client_stream.py | 214 ++++++++++++++++++++++- 3 files changed, 445 insertions(+), 74 deletions(-) diff --git a/autogen/oai/client.py b/autogen/oai/client.py index feaf56a7609..1bdfd835d1e 100644 --- a/autogen/oai/client.py +++ b/autogen/oai/client.py @@ -2,21 +2,36 @@ import os import sys -from typing import List, Optional, Dict, Callable, Union +from typing import Any, List, Optional, Dict, Callable, Tuple, Union import logging import inspect from flaml.automl.logger import logger_formatter -from pydantic import ValidationError + +from pydantic import BaseModel + +from autogen.oai import completion from autogen.oai.openai_utils import get_key, OAI_PRICE1K from autogen.token_count_utils import count_token +from autogen._pydantic import model_dump TOOL_ENABLED = False try: import openai +except ImportError: + ERROR: Optional[ImportError] = ImportError("Please install openai>=1 and diskcache to use autogen.OpenAIWrapper.") + OpenAI = object +else: + # raises exception if openai>=1 is installed and something is wrong with imports from openai import OpenAI, APIError, __version__ as OPENAIVERSION + from openai.resources import Completions from openai.types.chat import ChatCompletion - from openai.types.chat.chat_completion import ChatCompletionMessage, Choice + from openai.types.chat.chat_completion import ChatCompletionMessage, Choice # type: ignore [attr-defined] + from openai.types.chat.chat_completion_chunk import ( + ChoiceDeltaToolCall, + ChoiceDeltaToolCallFunction, + ChoiceDeltaFunctionCall, + ) from openai.types.completion import Completion from openai.types.completion_usage import CompletionUsage import diskcache @@ -24,9 +39,7 @@ if openai.__version__ >= "1.1.0": TOOL_ENABLED = True ERROR = None -except ImportError: - ERROR = ImportError("Please install openai>=1 and diskcache to use autogen.OpenAIWrapper.") - OpenAI = object + logger = logging.getLogger(__name__) if not logger.handlers: # Add the console handler. @@ -41,10 +54,10 @@ class OpenAIWrapper: cache_path_root: str = ".cache" extra_kwargs = {"cache_seed", "filter_func", "allow_format_str_template", "context", "api_version"} openai_kwargs = set(inspect.getfullargspec(OpenAI.__init__).kwonlyargs) - total_usage_summary: Dict = None - actual_usage_summary: Dict = None + total_usage_summary: Optional[Dict[str, Any]] = None + actual_usage_summary: Optional[Dict[str, Any]] = None - def __init__(self, *, config_list: List[Dict] = None, **base_config): + def __init__(self, *, config_list: Optional[List[Dict[str, Any]]] = None, **base_config: Any): """ Args: config_list: a list of config dicts to override the base_config. 
@@ -81,7 +94,9 @@ def __init__(self, *, config_list: List[Dict] = None, **base_config): logger.warning("openai client was provided with an empty config_list, which may not be intended.") if config_list: config_list = [config.copy() for config in config_list] # make a copy before modifying - self._clients = [self._client(config, openai_config) for config in config_list] # could modify the config + self._clients: List[OpenAI] = [ + self._client(config, openai_config) for config in config_list + ] # could modify the config self._config_list = [ {**extra_kwargs, **{k: v for k, v in config.items() if k not in self.openai_kwargs}} for config in config_list @@ -90,7 +105,9 @@ def __init__(self, *, config_list: List[Dict] = None, **base_config): self._clients = [self._client(extra_kwargs, openai_config)] self._config_list = [extra_kwargs] - def _process_for_azure(self, config: Dict, extra_kwargs: Dict, segment: str = "default"): + def _process_for_azure( + self, config: Dict[str, Any], extra_kwargs: Dict[str, Any], segment: str = "default" + ) -> None: # deal with api_version query_segment = f"{segment}_query" headers_segment = f"{segment}_headers" @@ -123,20 +140,20 @@ def _process_for_azure(self, config: Dict, extra_kwargs: Dict, segment: str = "d if not base_url.endswith(suffix): config["base_url"] += suffix[1:] if base_url.endswith("/") else suffix - def _separate_openai_config(self, config): + def _separate_openai_config(self, config: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]: """Separate the config into openai_config and extra_kwargs.""" openai_config = {k: v for k, v in config.items() if k in self.openai_kwargs} extra_kwargs = {k: v for k, v in config.items() if k not in self.openai_kwargs} self._process_for_azure(openai_config, extra_kwargs) return openai_config, extra_kwargs - def _separate_create_config(self, config): + def _separate_create_config(self, config: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]: """Separate the config into create_config and extra_kwargs.""" create_config = {k: v for k, v in config.items() if k not in self.extra_kwargs} extra_kwargs = {k: v for k, v in config.items() if k in self.extra_kwargs} return create_config, extra_kwargs - def _client(self, config, openai_config): + def _client(self, config: Dict[str, Any], openai_config: Dict[str, Any]) -> OpenAI: """Create a client with the given config to override openai_config, after removing extra kwargs. 
""" @@ -148,21 +165,21 @@ def _client(self, config, openai_config): @classmethod def instantiate( cls, - template: str | Callable | None, - context: Optional[Dict] = None, + template: Optional[Union[str, Callable[[Dict[str, Any]], str]]], + context: Optional[Dict[str, Any]] = None, allow_format_str_template: Optional[bool] = False, - ): + ) -> Optional[str]: if not context or template is None: - return template + return template # type: ignore [return-value] if isinstance(template, str): return template.format(**context) if allow_format_str_template else template return template(context) - def _construct_create_params(self, create_config: Dict, extra_kwargs: Dict) -> Dict: + def _construct_create_params(self, create_config: Dict[str, Any], extra_kwargs: Dict[str, Any]) -> Dict[str, Any]: """Prime the create_config with additional_kwargs.""" # Validate the config - prompt = create_config.get("prompt") - messages = create_config.get("messages") + prompt: Optional[str] = create_config.get("prompt") + messages: Optional[List[Dict[str, Any]]] = create_config.get("messages") if (prompt is None) == (messages is None): raise ValueError("Either prompt or messages should be in create config but not both.") context = extra_kwargs.get("context") @@ -185,11 +202,11 @@ def _construct_create_params(self, create_config: Dict, extra_kwargs: Dict) -> D } if m.get("content") else m - for m in messages + for m in messages # type: ignore [union-attr] ] return params - def create(self, **config): + def create(self, **config: Any) -> ChatCompletion: """Make a completion for a given config using openai's clients. Besides the kwargs allowed in openai's client, we allow the following additional kwargs. The config in each client will be overridden by the config. @@ -239,11 +256,11 @@ def yes_or_no_filter(context, response): with diskcache.Cache(f"{self.cache_path_root}/{cache_seed}") as cache: # Try to get the response from cache key = get_key(params) - response = cache.get(key, None) + response: ChatCompletion = cache.get(key, None) if response is not None: try: - response.cost + response.cost # type: ignore [attr-defined] except AttributeError: # update attribute if cost is not calculated response.cost = self.cost(response) @@ -264,7 +281,7 @@ def yes_or_no_filter(context, response): if error_code == "content_filter": # raise the error for content_filter raise - logger.debug(f"config {i} failed", exc_info=1) + logger.debug(f"config {i} failed", exc_info=True) if i == last: raise else: @@ -284,9 +301,129 @@ def yes_or_no_filter(context, response): response.pass_filter = pass_filter return response continue # filter is not passed; try the next config + raise RuntimeError("Should not reach here.") + + @staticmethod + def _update_dict_from_chunk(chunk: BaseModel, d: Dict[str, Any], field: str) -> int: + """Update the dict from the chunk. + + Reads `chunk.field` and if present updates `d[field]` accordingly. + + Args: + chunk: The chunk. + d: The dict to be updated in place. + field: The field. + + Returns: + The updated dict. + + """ + completion_tokens = 0 + assert isinstance(d, dict), d + if hasattr(chunk, field) and getattr(chunk, field) is not None: + new_value = getattr(chunk, field) + if isinstance(new_value, list) or isinstance(new_value, dict): + raise NotImplementedError( + f"Field {field} is a list or dict, which is currently not supported. " + "Only string and numbers are supported." 
+ ) + if field not in d: + d[field] = "" + if isinstance(new_value, str): + d[field] += getattr(chunk, field) + else: + d[field] = new_value + completion_tokens = 1 + + return completion_tokens + + @staticmethod + def _update_function_call_from_chunk( + function_call_chunk: Union[ChoiceDeltaToolCallFunction, ChoiceDeltaFunctionCall], + full_function_call: Optional[Dict[str, Any]], + completion_tokens: int, + ) -> Tuple[Dict[str, Any], int]: + """Update the function call from the chunk. + + Args: + function_call_chunk: The function call chunk. + full_function_call: The full function call. + completion_tokens: The number of completion tokens. + + Returns: + The updated full function call and the updated number of completion tokens. + + """ + # Handle function call + if function_call_chunk: + if full_function_call is None: + full_function_call = {} + for field in ["name", "arguments"]: + completion_tokens += OpenAIWrapper._update_dict_from_chunk( + function_call_chunk, full_function_call, field + ) - def _completions_create(self, client, params): - completions = client.chat.completions if "messages" in params else client.completions + if full_function_call: + return full_function_call, completion_tokens + else: + raise RuntimeError("Function call is not found, this should not happen.") + + @staticmethod + def _update_tool_calls_from_chunk( + tool_calls_chunk: ChoiceDeltaToolCall, + full_tool_call: Optional[Dict[str, Any]], + completion_tokens: int, + ) -> Tuple[Dict[str, Any], int]: + """Update the tool call from the chunk. + + Args: + tool_call_chunk: The tool call chunk. + full_tool_call: The full tool call. + completion_tokens: The number of completion tokens. + + Returns: + The updated full tool call and the updated number of completion tokens. + + """ + # future proofing for when tool calls other than function calls are supported + if tool_calls_chunk.type and tool_calls_chunk.type != "function": + raise NotImplementedError( + f"Tool call type {tool_calls_chunk.type} is currently not supported. " + "Only function calls are supported." + ) + + # Handle tool call + assert full_tool_call is None or isinstance(full_tool_call, dict), full_tool_call + if tool_calls_chunk: + if full_tool_call is None: + full_tool_call = {} + for field in ["index", "id", "type"]: + completion_tokens += OpenAIWrapper._update_dict_from_chunk(tool_calls_chunk, full_tool_call, field) + + if hasattr(tool_calls_chunk, "function") and tool_calls_chunk.function: + if "function" not in full_tool_call: + full_tool_call["function"] = None + + full_tool_call["function"], completion_tokens = OpenAIWrapper._update_function_call_from_chunk( + tool_calls_chunk.function, full_tool_call["function"], completion_tokens + ) + + if full_tool_call: + return full_tool_call, completion_tokens + else: + raise RuntimeError("Tool call is not found, this should not happen.") + + def _completions_create(self, client: OpenAI, params: Dict[str, Any]) -> ChatCompletion: + """Create a completion for a given config using openai's client. + + Args: + client: The openai client. + params: The params for the completion. + + Returns: + The completion. + """ + completions: Completions = client.chat.completions if "messages" in params else client.completions # type: ignore [attr-defined] # If streaming is enabled and has messages, then iterate over the chunks of the response. 
if params.get("stream", False) and "messages" in params: response_contents = [""] * params.get("n", 1) @@ -297,31 +434,52 @@ def _completions_create(self, client, params): print("\033[32m", end="") # Prepare for potential function call - full_function_call = None + full_function_call: Optional[Dict[str, Any]] = None + full_tool_calls: Optional[List[Optional[Dict[str, Any]]]] = None + # Send the chat completion request to OpenAI's API and process the response in chunks for chunk in completions.create(**params): if chunk.choices: for choice in chunk.choices: content = choice.delta.content - function_call_chunk = choice.delta.function_call + tool_calls_chunks = choice.delta.tool_calls finish_reasons[choice.index] = choice.finish_reason + # todo: remove this after function calls are removed from the API + # the code should work regardless of whether function calls are removed or not, but test_chat_functions_stream should fail + # begin block + function_call_chunk = ( + choice.delta.function_call if hasattr(choice.delta, "function_call") else None + ) # Handle function call if function_call_chunk: - if hasattr(function_call_chunk, "name") and function_call_chunk.name: - if full_function_call is None: - full_function_call = {"name": "", "arguments": ""} - full_function_call["name"] += function_call_chunk.name - completion_tokens += 1 - if hasattr(function_call_chunk, "arguments") and function_call_chunk.arguments: - full_function_call["arguments"] += function_call_chunk.arguments - completion_tokens += 1 - if choice.finish_reason == "function_call": - # Need something here? I don't think so. - pass - if not content: - continue - # End handle function call + # Handle function call + if function_call_chunk: + full_function_call, completion_tokens = self._update_function_call_from_chunk( + function_call_chunk, full_function_call, completion_tokens + ) + if not content: + continue + # end block + + # Handle tool calls + if tool_calls_chunks: + for tool_calls_chunk in tool_calls_chunks: + # the current tool call to be reconstructed + ix = tool_calls_chunk.index + if full_tool_calls is None: + full_tool_calls = [] + if ix >= len(full_tool_calls): + # in case ix is not sequential + full_tool_calls = full_tool_calls + [None] * (ix - len(full_tool_calls) + 1) + + full_tool_calls[ix], completion_tokens = self._update_tool_calls_from_chunk( + tool_calls_chunk, full_tool_calls[ix], completion_tokens + ) + if not content: + continue + + # End handle tool calls # If content is present, print it to the terminal and update response variables if content is not None: @@ -329,7 +487,8 @@ def _completions_create(self, client, params): response_contents[choice.index] += content completion_tokens += 1 else: - print() + # print() + pass # Reset the terminal text color print("\033[0m\n") @@ -356,17 +515,23 @@ def _completions_create(self, client, params): index=i, finish_reason=finish_reasons[i], message=ChatCompletionMessage( - role="assistant", content=response_contents[i], function_call=full_function_call + role="assistant", + content=response_contents[i], + function_call=full_function_call, + tool_calls=full_tool_calls, ), logprobs=None, ) else: # OpenAI versions below 1.5.0 - choice = Choice( + choice = Choice( # type: ignore [call-arg] index=i, finish_reason=finish_reasons[i], message=ChatCompletionMessage( - role="assistant", content=response_contents[i], function_call=full_function_call + role="assistant", + content=response_contents[i], + function_call=full_function_call, + tool_calls=full_tool_calls, ), ) @@ 
-379,7 +544,7 @@ def _completions_create(self, client, params): return response - def _update_usage_summary(self, response: ChatCompletion | Completion, use_cache: bool) -> None: + def _update_usage_summary(self, response: Union[ChatCompletion, Completion], use_cache: bool) -> None: """Update the usage summary. Usage is calculated no matter filter is passed or not. @@ -391,17 +556,17 @@ def _update_usage_summary(self, response: ChatCompletion | Completion, use_cache usage.completion_tokens = 0 if usage.completion_tokens is None else usage.completion_tokens usage.total_tokens = 0 if usage.total_tokens is None else usage.total_tokens except (AttributeError, AssertionError): - logger.debug("Usage attribute is not found in the response.", exc_info=1) + logger.debug("Usage attribute is not found in the response.", exc_info=True) return - def update_usage(usage_summary): + def update_usage(usage_summary: Optional[Dict[str, Any]]) -> Dict[str, Any]: if usage_summary is None: - usage_summary = {"total_cost": response.cost} + usage_summary = {"total_cost": response.cost} # type: ignore [union-attr] else: - usage_summary["total_cost"] += response.cost + usage_summary["total_cost"] += response.cost # type: ignore [union-attr] usage_summary[response.model] = { - "cost": usage_summary.get(response.model, {}).get("cost", 0) + response.cost, + "cost": usage_summary.get(response.model, {}).get("cost", 0) + response.cost, # type: ignore [union-attr] "prompt_tokens": usage_summary.get(response.model, {}).get("prompt_tokens", 0) + usage.prompt_tokens, "completion_tokens": usage_summary.get(response.model, {}).get("completion_tokens", 0) + usage.completion_tokens, @@ -416,7 +581,7 @@ def update_usage(usage_summary): def print_usage_summary(self, mode: Union[str, List[str]] = ["actual", "total"]) -> None: """Print the usage summary.""" - def print_usage(usage_summary, usage_type="total"): + def print_usage(usage_summary: Optional[Dict[str, Any]], usage_type: str = "total") -> None: word_from_type = "including" if usage_type == "total" else "excluding" if usage_summary is None: print("No actual cost incurred (all completions are using cache).", flush=True) @@ -475,20 +640,20 @@ def cost(self, response: Union[ChatCompletion, Completion]) -> float: model = response.model if model not in OAI_PRICE1K: # TODO: add logging to warn that the model is not found - logger.debug(f"Model {model} is not found. The cost will be 0.", exc_info=1) + logger.debug(f"Model {model} is not found. 
The cost will be 0.", exc_info=True) return 0 - n_input_tokens = response.usage.prompt_tokens - n_output_tokens = response.usage.completion_tokens + n_input_tokens = response.usage.prompt_tokens # type: ignore [union-attr] + n_output_tokens = response.usage.completion_tokens # type: ignore [union-attr] tmp_price1K = OAI_PRICE1K[model] # First value is input token rate, second value is output token rate if isinstance(tmp_price1K, tuple): - return (tmp_price1K[0] * n_input_tokens + tmp_price1K[1] * n_output_tokens) / 1000 - return tmp_price1K * (n_input_tokens + n_output_tokens) / 1000 + return (tmp_price1K[0] * n_input_tokens + tmp_price1K[1] * n_output_tokens) / 1000 # type: ignore [no-any-return] + return tmp_price1K * (n_input_tokens + n_output_tokens) / 1000 # type: ignore [operator] @classmethod def extract_text_or_completion_object( - cls, response: ChatCompletion | Completion + cls, response: Union[ChatCompletion, Completion] ) -> Union[List[str], List[ChatCompletionMessage]]: """Extract the text or ChatCompletion objects from a completion or chat response. @@ -500,18 +665,18 @@ def extract_text_or_completion_object( """ choices = response.choices if isinstance(response, Completion): - return [choice.text for choice in choices] + return [choice.text for choice in choices] # type: ignore [union-attr] if TOOL_ENABLED: - return [ - choice.message - if choice.message.function_call is not None or choice.message.tool_calls is not None - else choice.message.content + return [ # type: ignore [return-value] + choice.message # type: ignore [union-attr] + if choice.message.function_call is not None or choice.message.tool_calls is not None # type: ignore [union-attr] + else choice.message.content # type: ignore [union-attr] for choice in choices ] else: - return [ - choice.message if choice.message.function_call is not None else choice.message.content + return [ # type: ignore [return-value] + choice.message if choice.message.function_call is not None else choice.message.content # type: ignore [union-attr] for choice in choices ] diff --git a/autogen/oai/openai_utils.py b/autogen/oai/openai_utils.py index 6e9bce11f12..525a1027cc0 100644 --- a/autogen/oai/openai_utils.py +++ b/autogen/oai/openai_utils.py @@ -3,7 +3,7 @@ import os import tempfile from pathlib import Path -from typing import Dict, List, Optional, Set, Union +from typing import Any, Dict, List, Optional, Set, Union from dotenv import find_dotenv, load_dotenv @@ -50,7 +50,7 @@ } -def get_key(config): +def get_key(config: Dict[str, Any]) -> str: """Get a unique identifier of a configuration. 
     Args:
 
diff --git a/test/oai/test_client_stream.py b/test/oai/test_client_stream.py
index 284df095466..6a20c4ffa21 100644
--- a/test/oai/test_client_stream.py
+++ b/test/oai/test_client_stream.py
@@ -1,3 +1,6 @@
+import json
+from typing import Any, Dict, List, Literal, Optional, Union
+from unittest.mock import MagicMock
 import pytest
 from autogen import OpenAIWrapper, config_list_from_json, config_list_openai_aoai
 import sys
@@ -13,12 +16,21 @@
 else:
     skip = False or skip_openai
 
+    # raises exception if openai>=1 is installed and something is wrong with imports
+    # otherwise the test will be skipped
+    from openai.types.chat.chat_completion_chunk import (
+        ChoiceDeltaFunctionCall,
+        ChoiceDeltaToolCall,
+        ChoiceDeltaToolCallFunction,
+    )
+    from openai.types.chat.chat_completion import ChatCompletionMessage  # type: ignore [attr-defined]
+
 KEY_LOC = "notebook"
 OAI_CONFIG_LIST = "OAI_CONFIG_LIST"
 
 
 @pytest.mark.skipif(skip, reason="openai>=1 not installed")
-def test_aoai_chat_completion_stream():
+def test_aoai_chat_completion_stream() -> None:
     config_list = config_list_from_json(
         env_or_file=OAI_CONFIG_LIST,
         file_location=KEY_LOC,
@@ -31,7 +43,7 @@ def test_aoai_chat_completion_stream():
 
 
 @pytest.mark.skipif(skip, reason="openai>=1 not installed")
-def test_chat_completion_stream():
+def test_chat_completion_stream() -> None:
     config_list = config_list_from_json(
         env_or_file=OAI_CONFIG_LIST,
         file_location=KEY_LOC,
@@ -43,8 +55,147 @@ def test_chat_completion_stream():
     print(client.extract_text_or_completion_object(response))
 
 
+# no need for OpenAI, works with any model
+def test__update_dict_from_chunk() -> None:
+    # dictionaries and lists are not supported
+    mock = MagicMock()
+    empty_collections: List[Union[List[Any], Dict[str, Any]]] = [{}, []]
+    for c in empty_collections:
+        mock.c = c
+        with pytest.raises(NotImplementedError):
+            OpenAIWrapper._update_dict_from_chunk(mock, {}, "c")
+
+    org_d: Dict[str, Any] = {}
+    for i, v in enumerate([0, 1, False, True, 0.0, 1.0]):
+        field = "abcdefghijklmnopqrstuvwxyz"[i]
+        setattr(mock, field, v)
+
+        d = org_d.copy()
+        OpenAIWrapper._update_dict_from_chunk(mock, d, field)
+
+        org_d[field] = v
+        assert d == org_d
+
+    mock.s = "beginning"
+    OpenAIWrapper._update_dict_from_chunk(mock, d, "s")
+    assert d["s"] == "beginning"
+
+    mock.s = " and"
+    OpenAIWrapper._update_dict_from_chunk(mock, d, "s")
+    assert d["s"] == "beginning and"
+
+    mock.s = " end"
+    OpenAIWrapper._update_dict_from_chunk(mock, d, "s")
+    assert d["s"] == "beginning and end"
+
+
+@pytest.mark.skipif(skip, reason="openai>=1 not installed")
+def test__update_function_call_from_chunk() -> None:
+    function_call_chunks = [
+        ChoiceDeltaFunctionCall(arguments=None, name="get_current_weather"),
+        ChoiceDeltaFunctionCall(arguments='{"', name=None),
+        ChoiceDeltaFunctionCall(arguments="location", name=None),
+        ChoiceDeltaFunctionCall(arguments='":"', name=None),
+        ChoiceDeltaFunctionCall(arguments="San", name=None),
+        ChoiceDeltaFunctionCall(arguments=" Francisco", name=None),
+        ChoiceDeltaFunctionCall(arguments='"}', name=None),
+    ]
+    expected = {"name": "get_current_weather", "arguments": '{"location":"San Francisco"}'}
+
+    full_function_call = None
+    completion_tokens = 0
+    for function_call_chunk in function_call_chunks:
+        # print(f"{function_call_chunk=}")
+        full_function_call, completion_tokens = OpenAIWrapper._update_function_call_from_chunk(
+            function_call_chunk=function_call_chunk,
+            full_function_call=full_function_call,
+            completion_tokens=completion_tokens,
+        )
+
+    
print(f"{full_function_call=}") + print(f"{completion_tokens=}") + + assert full_function_call == expected + assert completion_tokens == len(function_call_chunks) + + ChatCompletionMessage(role="assistant", function_call=full_function_call, content=None) + + @pytest.mark.skipif(skip, reason="openai>=1 not installed") -def test_chat_functions_stream(): +def test__update_tool_calls_from_chunk() -> None: + tool_calls_chunks = [ + ChoiceDeltaToolCall( + index=0, + id="call_D2HOWGMekmkxXu9Ix3DUqJRv", + function=ChoiceDeltaToolCallFunction(arguments="", name="get_current_weather"), + type="function", + ), + ChoiceDeltaToolCall( + index=0, id=None, function=ChoiceDeltaToolCallFunction(arguments='{"lo', name=None), type=None + ), + ChoiceDeltaToolCall( + index=0, id=None, function=ChoiceDeltaToolCallFunction(arguments="catio", name=None), type=None + ), + ChoiceDeltaToolCall( + index=0, id=None, function=ChoiceDeltaToolCallFunction(arguments='n": "S', name=None), type=None + ), + ChoiceDeltaToolCall( + index=0, id=None, function=ChoiceDeltaToolCallFunction(arguments="an F", name=None), type=None + ), + ChoiceDeltaToolCall( + index=0, id=None, function=ChoiceDeltaToolCallFunction(arguments="ranci", name=None), type=None + ), + ChoiceDeltaToolCall( + index=0, id=None, function=ChoiceDeltaToolCallFunction(arguments="sco, C", name=None), type=None + ), + ChoiceDeltaToolCall( + index=0, id=None, function=ChoiceDeltaToolCallFunction(arguments='A"}', name=None), type=None + ), + ChoiceDeltaToolCall( + index=1, + id="call_22HgJep4nwoKU3UOr96xaLmd", + function=ChoiceDeltaToolCallFunction(arguments="", name="get_current_weather"), + type="function", + ), + ChoiceDeltaToolCall( + index=1, id=None, function=ChoiceDeltaToolCallFunction(arguments='{"lo', name=None), type=None + ), + ChoiceDeltaToolCall( + index=1, id=None, function=ChoiceDeltaToolCallFunction(arguments="catio", name=None), type=None + ), + ChoiceDeltaToolCall( + index=1, id=None, function=ChoiceDeltaToolCallFunction(arguments='n": "N', name=None), type=None + ), + ChoiceDeltaToolCall( + index=1, id=None, function=ChoiceDeltaToolCallFunction(arguments="ew Y", name=None), type=None + ), + ChoiceDeltaToolCall( + index=1, id=None, function=ChoiceDeltaToolCallFunction(arguments="ork, ", name=None), type=None + ), + ChoiceDeltaToolCall( + index=1, id=None, function=ChoiceDeltaToolCallFunction(arguments='NY"}', name=None), type=None + ), + ] + + full_tool_calls: List[Optional[Dict[str, Any]]] = [None, None] + completion_tokens = 0 + for tool_calls_chunk in tool_calls_chunks: + index = tool_calls_chunk.index + full_tool_calls[index], completion_tokens = OpenAIWrapper._update_tool_calls_from_chunk( + tool_calls_chunk=tool_calls_chunk, + full_tool_call=full_tool_calls[index], + completion_tokens=completion_tokens, + ) + + print(f"{full_tool_calls=}") + print(f"{completion_tokens=}") + + ChatCompletionMessage(role="assistant", tool_calls=full_tool_calls, content=None) + + +# todo: remove when OpenAI removes functions from the API +@pytest.mark.skipif(skip, reason="openai>=1 not installed") +def test_chat_functions_stream() -> None: config_list = config_list_from_json( env_or_file=OAI_CONFIG_LIST, file_location=KEY_LOC, @@ -76,8 +227,63 @@ def test_chat_functions_stream(): print(client.extract_text_or_completion_object(response)) +# test for tool support instead of the deprecated function calls +@pytest.mark.skipif(skip, reason="openai>=1 not installed") +def test_chat_tools_stream() -> None: + config_list = config_list_from_json( + 
env_or_file=OAI_CONFIG_LIST, + file_location=KEY_LOC, + filter_dict={"model": ["gpt-3.5-turbo", "gpt-35-turbo"]}, + ) + tools = [ + { + "type": "function", + "function": { + "name": "get_current_weather", + "description": "Get the current weather", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA", + }, + }, + "required": ["location"], + }, + }, + }, + ] + print(f"{config_list=}") + client = OpenAIWrapper(config_list=config_list) + response = client.create( + # the intention is to trigger two tool invocations as a response to a single message + messages=[{"role": "user", "content": "What's the weather like today in San Francisco and New York?"}], + tools=tools, + stream=True, + ) + print(f"{response=}") + print(f"{type(response)=}") + print(f"{client.extract_text_or_completion_object(response)=}") + # check response + choices = response.choices + assert isinstance(choices, list) + assert len(choices) == 1 + choice = choices[0] + assert choice.finish_reason == "tool_calls" + message = choice.message + tool_calls = message.tool_calls + assert isinstance(tool_calls, list) + assert len(tool_calls) == 2 + arguments = [tool_call.function.arguments for tool_call in tool_calls] + locations = [json.loads(argument)["location"] for argument in arguments] + print(f"{locations=}") + assert any(["San Francisco" in location for location in locations]) + assert any(["New York" in location for location in locations]) + + @pytest.mark.skipif(skip, reason="openai>=1 not installed") -def test_completion_stream(): +def test_completion_stream() -> None: config_list = config_list_openai_aoai(KEY_LOC) client = OpenAIWrapper(config_list=config_list) response = client.create(prompt="1+1=", model="gpt-3.5-turbo-instruct", stream=True)
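
Usage note: with stream=True the wrapper merges the streamed chunks back into a single
ChatCompletion, so callers consume tool calls exactly as in the non-streaming case.
A minimal sketch (the `tools` schema and the config filter follow test_chat_tools_stream
above; the message text is illustrative):

    import json

    from autogen import OpenAIWrapper, config_list_from_json

    config_list = config_list_from_json("OAI_CONFIG_LIST", file_location="notebook")
    client = OpenAIWrapper(config_list=config_list)
    response = client.create(
        messages=[{"role": "user", "content": "What's the weather in San Francisco and New York?"}],
        tools=tools,  # the get_current_weather schema from test_chat_tools_stream
        stream=True,  # streamed deltas are reassembled into one ChatCompletion
    )
    for tool_call in response.choices[0].message.tool_calls or []:
        arguments = json.loads(tool_call.function.arguments)
        print(tool_call.function.name, arguments["location"])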