From 6aa424d076098312801febd938bd4b5e8baf4851 Mon Sep 17 00:00:00 2001 From: Robert Craigie Date: Tue, 5 Nov 2024 09:31:27 +0000 Subject: [PATCH] fix: add new prediction param to all methods --- src/openai/resources/beta/chat/completions.py | 9 ++++ tests/lib/chat/test_completions.py | 42 ++++++++++++++++--- tests/lib/chat/test_completions_streaming.py | 7 +++- 3 files changed, 51 insertions(+), 7 deletions(-) diff --git a/src/openai/resources/beta/chat/completions.py b/src/openai/resources/beta/chat/completions.py index 47bcf42c16..38c09ce8dd 100644 --- a/src/openai/resources/beta/chat/completions.py +++ b/src/openai/resources/beta/chat/completions.py @@ -33,6 +33,7 @@ from ....types.chat.chat_completion_audio_param import ChatCompletionAudioParam from ....types.chat.chat_completion_message_param import ChatCompletionMessageParam from ....types.chat.chat_completion_stream_options_param import ChatCompletionStreamOptionsParam +from ....types.chat.chat_completion_prediction_content_param import ChatCompletionPredictionContentParam from ....types.chat.chat_completion_tool_choice_option_param import ChatCompletionToolChoiceOptionParam __all__ = ["Completions", "AsyncCompletions"] @@ -76,6 +77,7 @@ def parse( modalities: Optional[List[ChatCompletionModality]] | NotGiven = NOT_GIVEN, n: Optional[int] | NotGiven = NOT_GIVEN, parallel_tool_calls: bool | NotGiven = NOT_GIVEN, + prediction: Optional[ChatCompletionPredictionContentParam] | NotGiven = NOT_GIVEN, presence_penalty: Optional[float] | NotGiven = NOT_GIVEN, seed: Optional[int] | NotGiven = NOT_GIVEN, service_tier: Optional[Literal["auto", "default"]] | NotGiven = NOT_GIVEN, @@ -169,6 +171,7 @@ def parser(raw_completion: ChatCompletion) -> ParsedChatCompletion[ResponseForma "modalities": modalities, "n": n, "parallel_tool_calls": parallel_tool_calls, + "prediction": prediction, "presence_penalty": presence_penalty, "response_format": _type_to_response_format(response_format), "seed": seed, @@ -217,6 +220,7 @@ def stream( modalities: Optional[List[ChatCompletionModality]] | NotGiven = NOT_GIVEN, n: Optional[int] | NotGiven = NOT_GIVEN, parallel_tool_calls: bool | NotGiven = NOT_GIVEN, + prediction: Optional[ChatCompletionPredictionContentParam] | NotGiven = NOT_GIVEN, presence_penalty: Optional[float] | NotGiven = NOT_GIVEN, seed: Optional[int] | NotGiven = NOT_GIVEN, service_tier: Optional[Literal["auto", "default"]] | NotGiven = NOT_GIVEN, @@ -281,6 +285,7 @@ def stream( modalities=modalities, n=n, parallel_tool_calls=parallel_tool_calls, + prediction=prediction, presence_penalty=presence_penalty, seed=seed, service_tier=service_tier, @@ -343,6 +348,7 @@ async def parse( modalities: Optional[List[ChatCompletionModality]] | NotGiven = NOT_GIVEN, n: Optional[int] | NotGiven = NOT_GIVEN, parallel_tool_calls: bool | NotGiven = NOT_GIVEN, + prediction: Optional[ChatCompletionPredictionContentParam] | NotGiven = NOT_GIVEN, presence_penalty: Optional[float] | NotGiven = NOT_GIVEN, seed: Optional[int] | NotGiven = NOT_GIVEN, service_tier: Optional[Literal["auto", "default"]] | NotGiven = NOT_GIVEN, @@ -436,6 +442,7 @@ def parser(raw_completion: ChatCompletion) -> ParsedChatCompletion[ResponseForma "modalities": modalities, "n": n, "parallel_tool_calls": parallel_tool_calls, + "prediction": prediction, "presence_penalty": presence_penalty, "response_format": _type_to_response_format(response_format), "seed": seed, @@ -484,6 +491,7 @@ def stream( modalities: Optional[List[ChatCompletionModality]] | NotGiven = NOT_GIVEN, n: Optional[int] | NotGiven = NOT_GIVEN, parallel_tool_calls: bool | NotGiven = NOT_GIVEN, + prediction: Optional[ChatCompletionPredictionContentParam] | NotGiven = NOT_GIVEN, presence_penalty: Optional[float] | NotGiven = NOT_GIVEN, seed: Optional[int] | NotGiven = NOT_GIVEN, service_tier: Optional[Literal["auto", "default"]] | NotGiven = NOT_GIVEN, @@ -549,6 +557,7 @@ def stream( modalities=modalities, n=n, parallel_tool_calls=parallel_tool_calls, + prediction=prediction, presence_penalty=presence_penalty, seed=seed, service_tier=service_tier, diff --git a/tests/lib/chat/test_completions.py b/tests/lib/chat/test_completions.py index 5cd7b1ee53..48f41eb221 100644 --- a/tests/lib/chat/test_completions.py +++ b/tests/lib/chat/test_completions.py @@ -77,7 +77,12 @@ def test_parse_nothing(client: OpenAI, respx_mock: MockRouter, monkeypatch: pyte system_fingerprint='fp_b40fb1c6fb', usage=CompletionUsage( completion_tokens=37, - completion_tokens_details=CompletionTokensDetails(audio_tokens=None, reasoning_tokens=0), + completion_tokens_details=CompletionTokensDetails( + accepted_prediction_tokens=None, + audio_tokens=None, + reasoning_tokens=0, + rejected_prediction_tokens=None + ), prompt_tokens=14, prompt_tokens_details=None, total_tokens=51 @@ -139,7 +144,12 @@ class Location(BaseModel): system_fingerprint='fp_5050236cbd', usage=CompletionUsage( completion_tokens=14, - completion_tokens_details=CompletionTokensDetails(audio_tokens=None, reasoning_tokens=0), + completion_tokens_details=CompletionTokensDetails( + accepted_prediction_tokens=None, + audio_tokens=None, + reasoning_tokens=0, + rejected_prediction_tokens=None + ), prompt_tokens=79, prompt_tokens_details=None, total_tokens=93 @@ -203,7 +213,12 @@ class Location(BaseModel): system_fingerprint='fp_b40fb1c6fb', usage=CompletionUsage( completion_tokens=14, - completion_tokens_details=CompletionTokensDetails(audio_tokens=None, reasoning_tokens=0), + completion_tokens_details=CompletionTokensDetails( + accepted_prediction_tokens=None, + audio_tokens=None, + reasoning_tokens=0, + rejected_prediction_tokens=None + ), prompt_tokens=88, prompt_tokens_details=None, total_tokens=102 @@ -396,7 +411,12 @@ class CalendarEvent: system_fingerprint='fp_7568d46099', usage=CompletionUsage( completion_tokens=17, - completion_tokens_details=CompletionTokensDetails(audio_tokens=None, reasoning_tokens=0), + completion_tokens_details=CompletionTokensDetails( + accepted_prediction_tokens=None, + audio_tokens=None, + reasoning_tokens=0, + rejected_prediction_tokens=None + ), prompt_tokens=92, prompt_tokens_details=None, total_tokens=109 @@ -847,7 +867,12 @@ class Location(BaseModel): system_fingerprint='fp_5050236cbd', usage=CompletionUsage( completion_tokens=14, - completion_tokens_details=CompletionTokensDetails(audio_tokens=None, reasoning_tokens=0), + completion_tokens_details=CompletionTokensDetails( + accepted_prediction_tokens=None, + audio_tokens=None, + reasoning_tokens=0, + rejected_prediction_tokens=None + ), prompt_tokens=79, prompt_tokens_details=None, total_tokens=93 @@ -917,7 +942,12 @@ class Location(BaseModel): system_fingerprint='fp_5050236cbd', usage=CompletionUsage( completion_tokens=14, - completion_tokens_details=CompletionTokensDetails(audio_tokens=None, reasoning_tokens=0), + completion_tokens_details=CompletionTokensDetails( + accepted_prediction_tokens=None, + audio_tokens=None, + reasoning_tokens=0, + rejected_prediction_tokens=None + ), prompt_tokens=79, prompt_tokens_details=None, total_tokens=93 diff --git a/tests/lib/chat/test_completions_streaming.py b/tests/lib/chat/test_completions_streaming.py index 2846e6d2c3..ab12de44b3 100644 --- a/tests/lib/chat/test_completions_streaming.py +++ b/tests/lib/chat/test_completions_streaming.py @@ -157,7 +157,12 @@ def on_event(stream: ChatCompletionStream[Location], event: ChatCompletionStream system_fingerprint='fp_5050236cbd', usage=CompletionUsage( completion_tokens=14, - completion_tokens_details=CompletionTokensDetails(audio_tokens=None, reasoning_tokens=0), + completion_tokens_details=CompletionTokensDetails( + accepted_prediction_tokens=None, + audio_tokens=None, + reasoning_tokens=0, + rejected_prediction_tokens=None + ), prompt_tokens=79, prompt_tokens_details=None, total_tokens=93