Skip to content

Commit

Permalink
fix: add new prediction param to all methods
Browse files Browse the repository at this point in the history
  • Loading branch information
RobertCraigie committed Nov 5, 2024
1 parent b32507d commit 6aa424d
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 7 deletions.
9 changes: 9 additions & 0 deletions src/openai/resources/beta/chat/completions.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from ....types.chat.chat_completion_audio_param import ChatCompletionAudioParam
from ....types.chat.chat_completion_message_param import ChatCompletionMessageParam
from ....types.chat.chat_completion_stream_options_param import ChatCompletionStreamOptionsParam
from ....types.chat.chat_completion_prediction_content_param import ChatCompletionPredictionContentParam
from ....types.chat.chat_completion_tool_choice_option_param import ChatCompletionToolChoiceOptionParam

__all__ = ["Completions", "AsyncCompletions"]
Expand Down Expand Up @@ -76,6 +77,7 @@ def parse(
modalities: Optional[List[ChatCompletionModality]] | NotGiven = NOT_GIVEN,
n: Optional[int] | NotGiven = NOT_GIVEN,
parallel_tool_calls: bool | NotGiven = NOT_GIVEN,
prediction: Optional[ChatCompletionPredictionContentParam] | NotGiven = NOT_GIVEN,
presence_penalty: Optional[float] | NotGiven = NOT_GIVEN,
seed: Optional[int] | NotGiven = NOT_GIVEN,
service_tier: Optional[Literal["auto", "default"]] | NotGiven = NOT_GIVEN,
Expand Down Expand Up @@ -169,6 +171,7 @@ def parser(raw_completion: ChatCompletion) -> ParsedChatCompletion[ResponseForma
"modalities": modalities,
"n": n,
"parallel_tool_calls": parallel_tool_calls,
"prediction": prediction,
"presence_penalty": presence_penalty,
"response_format": _type_to_response_format(response_format),
"seed": seed,
Expand Down Expand Up @@ -217,6 +220,7 @@ def stream(
modalities: Optional[List[ChatCompletionModality]] | NotGiven = NOT_GIVEN,
n: Optional[int] | NotGiven = NOT_GIVEN,
parallel_tool_calls: bool | NotGiven = NOT_GIVEN,
prediction: Optional[ChatCompletionPredictionContentParam] | NotGiven = NOT_GIVEN,
presence_penalty: Optional[float] | NotGiven = NOT_GIVEN,
seed: Optional[int] | NotGiven = NOT_GIVEN,
service_tier: Optional[Literal["auto", "default"]] | NotGiven = NOT_GIVEN,
Expand Down Expand Up @@ -281,6 +285,7 @@ def stream(
modalities=modalities,
n=n,
parallel_tool_calls=parallel_tool_calls,
prediction=prediction,
presence_penalty=presence_penalty,
seed=seed,
service_tier=service_tier,
Expand Down Expand Up @@ -343,6 +348,7 @@ async def parse(
modalities: Optional[List[ChatCompletionModality]] | NotGiven = NOT_GIVEN,
n: Optional[int] | NotGiven = NOT_GIVEN,
parallel_tool_calls: bool | NotGiven = NOT_GIVEN,
prediction: Optional[ChatCompletionPredictionContentParam] | NotGiven = NOT_GIVEN,
presence_penalty: Optional[float] | NotGiven = NOT_GIVEN,
seed: Optional[int] | NotGiven = NOT_GIVEN,
service_tier: Optional[Literal["auto", "default"]] | NotGiven = NOT_GIVEN,
Expand Down Expand Up @@ -436,6 +442,7 @@ def parser(raw_completion: ChatCompletion) -> ParsedChatCompletion[ResponseForma
"modalities": modalities,
"n": n,
"parallel_tool_calls": parallel_tool_calls,
"prediction": prediction,
"presence_penalty": presence_penalty,
"response_format": _type_to_response_format(response_format),
"seed": seed,
Expand Down Expand Up @@ -484,6 +491,7 @@ def stream(
modalities: Optional[List[ChatCompletionModality]] | NotGiven = NOT_GIVEN,
n: Optional[int] | NotGiven = NOT_GIVEN,
parallel_tool_calls: bool | NotGiven = NOT_GIVEN,
prediction: Optional[ChatCompletionPredictionContentParam] | NotGiven = NOT_GIVEN,
presence_penalty: Optional[float] | NotGiven = NOT_GIVEN,
seed: Optional[int] | NotGiven = NOT_GIVEN,
service_tier: Optional[Literal["auto", "default"]] | NotGiven = NOT_GIVEN,
Expand Down Expand Up @@ -549,6 +557,7 @@ def stream(
modalities=modalities,
n=n,
parallel_tool_calls=parallel_tool_calls,
prediction=prediction,
presence_penalty=presence_penalty,
seed=seed,
service_tier=service_tier,
Expand Down
42 changes: 36 additions & 6 deletions tests/lib/chat/test_completions.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,12 @@ def test_parse_nothing(client: OpenAI, respx_mock: MockRouter, monkeypatch: pyte
system_fingerprint='fp_b40fb1c6fb',
usage=CompletionUsage(
completion_tokens=37,
completion_tokens_details=CompletionTokensDetails(audio_tokens=None, reasoning_tokens=0),
completion_tokens_details=CompletionTokensDetails(
accepted_prediction_tokens=None,
audio_tokens=None,
reasoning_tokens=0,
rejected_prediction_tokens=None
),
prompt_tokens=14,
prompt_tokens_details=None,
total_tokens=51
Expand Down Expand Up @@ -139,7 +144,12 @@ class Location(BaseModel):
system_fingerprint='fp_5050236cbd',
usage=CompletionUsage(
completion_tokens=14,
completion_tokens_details=CompletionTokensDetails(audio_tokens=None, reasoning_tokens=0),
completion_tokens_details=CompletionTokensDetails(
accepted_prediction_tokens=None,
audio_tokens=None,
reasoning_tokens=0,
rejected_prediction_tokens=None
),
prompt_tokens=79,
prompt_tokens_details=None,
total_tokens=93
Expand Down Expand Up @@ -203,7 +213,12 @@ class Location(BaseModel):
system_fingerprint='fp_b40fb1c6fb',
usage=CompletionUsage(
completion_tokens=14,
completion_tokens_details=CompletionTokensDetails(audio_tokens=None, reasoning_tokens=0),
completion_tokens_details=CompletionTokensDetails(
accepted_prediction_tokens=None,
audio_tokens=None,
reasoning_tokens=0,
rejected_prediction_tokens=None
),
prompt_tokens=88,
prompt_tokens_details=None,
total_tokens=102
Expand Down Expand Up @@ -396,7 +411,12 @@ class CalendarEvent:
system_fingerprint='fp_7568d46099',
usage=CompletionUsage(
completion_tokens=17,
completion_tokens_details=CompletionTokensDetails(audio_tokens=None, reasoning_tokens=0),
completion_tokens_details=CompletionTokensDetails(
accepted_prediction_tokens=None,
audio_tokens=None,
reasoning_tokens=0,
rejected_prediction_tokens=None
),
prompt_tokens=92,
prompt_tokens_details=None,
total_tokens=109
Expand Down Expand Up @@ -847,7 +867,12 @@ class Location(BaseModel):
system_fingerprint='fp_5050236cbd',
usage=CompletionUsage(
completion_tokens=14,
completion_tokens_details=CompletionTokensDetails(audio_tokens=None, reasoning_tokens=0),
completion_tokens_details=CompletionTokensDetails(
accepted_prediction_tokens=None,
audio_tokens=None,
reasoning_tokens=0,
rejected_prediction_tokens=None
),
prompt_tokens=79,
prompt_tokens_details=None,
total_tokens=93
Expand Down Expand Up @@ -917,7 +942,12 @@ class Location(BaseModel):
system_fingerprint='fp_5050236cbd',
usage=CompletionUsage(
completion_tokens=14,
completion_tokens_details=CompletionTokensDetails(audio_tokens=None, reasoning_tokens=0),
completion_tokens_details=CompletionTokensDetails(
accepted_prediction_tokens=None,
audio_tokens=None,
reasoning_tokens=0,
rejected_prediction_tokens=None
),
prompt_tokens=79,
prompt_tokens_details=None,
total_tokens=93
Expand Down
7 changes: 6 additions & 1 deletion tests/lib/chat/test_completions_streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,12 @@ def on_event(stream: ChatCompletionStream[Location], event: ChatCompletionStream
system_fingerprint='fp_5050236cbd',
usage=CompletionUsage(
completion_tokens=14,
completion_tokens_details=CompletionTokensDetails(audio_tokens=None, reasoning_tokens=0),
completion_tokens_details=CompletionTokensDetails(
accepted_prediction_tokens=None,
audio_tokens=None,
reasoning_tokens=0,
rejected_prediction_tokens=None
),
prompt_tokens=79,
prompt_tokens_details=None,
total_tokens=93
Expand Down

0 comments on commit 6aa424d

Please sign in to comment.