Skip to content

Commit

Permalink
fix(audio): correct types for transcriptions / translations (#1755)
Browse files Browse the repository at this point in the history
fix
  • Loading branch information
RobertCraigie committed Sep 27, 2024
1 parent 7a5632f commit 76c1f3f
Show file tree
Hide file tree
Showing 17 changed files with 543 additions and 43 deletions.
2 changes: 1 addition & 1 deletion .stats.yml
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
configured_endpoints: 68
openapi_spec_url: https://storage.googleapis.com/stainless-sdk-openapi-specs/openai-17ddd746c775ca4d4fbe64e5621ac30756ef09c061ff6313190b6ec162222d4c.yml
openapi_spec_url: https://storage.googleapis.com/stainless-sdk-openapi-specs/openai-71e58a77027c67e003fdd1b1ac8ac11557d8bfabc7666d1a827c6b1ca8ab98b5.yml
14 changes: 10 additions & 4 deletions api.md
Original file line number Diff line number Diff line change
Expand Up @@ -122,24 +122,30 @@ from openai.types import AudioModel, AudioResponseFormat
Types:

```python
from openai.types.audio import Transcription
from openai.types.audio import (
Transcription,
TranscriptionSegment,
TranscriptionVerbose,
TranscriptionWord,
TranscriptionCreateResponse,
)
```

Methods:

- <code title="post /audio/transcriptions">client.audio.transcriptions.<a href="./src/openai/resources/audio/transcriptions.py">create</a>(\*\*<a href="src/openai/types/audio/transcription_create_params.py">params</a>) -> <a href="./src/openai/types/audio/transcription.py">Transcription</a></code>
- <code title="post /audio/transcriptions">client.audio.transcriptions.<a href="./src/openai/resources/audio/transcriptions.py">create</a>(\*\*<a href="src/openai/types/audio/transcription_create_params.py">params</a>) -> <a href="./src/openai/types/audio/transcription_create_response.py">TranscriptionCreateResponse</a></code>

## Translations

Types:

```python
from openai.types.audio import Translation
from openai.types.audio import Translation, TranslationVerbose, TranslationCreateResponse
```

Methods:

- <code title="post /audio/translations">client.audio.translations.<a href="./src/openai/resources/audio/translations.py">create</a>(\*\*<a href="src/openai/types/audio/translation_create_params.py">params</a>) -> <a href="./src/openai/types/audio/translation.py">Translation</a></code>
- <code title="post /audio/translations">client.audio.translations.<a href="./src/openai/resources/audio/translations.py">create</a>(\*\*<a href="src/openai/types/audio/translation_create_params.py">params</a>) -> <a href="./src/openai/types/audio/translation_create_response.py">TranslationCreateResponse</a></code>

## Speech

Expand Down
5 changes: 4 additions & 1 deletion src/openai/_utils/_reflection.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ def assert_signatures_in_sync(
check_func: Callable[..., Any],
*,
exclude_params: set[str] = set(),
description: str = "",
) -> None:
"""Ensure that the signature of the second function matches the first."""

Expand All @@ -39,4 +40,6 @@ def assert_signatures_in_sync(
continue

if errors:
raise AssertionError(f"{len(errors)} errors encountered when comparing signatures:\n\n" + "\n\n".join(errors))
raise AssertionError(
f"{len(errors)} errors encountered when comparing signatures{description}:\n\n" + "\n\n".join(errors)
)
155 changes: 146 additions & 9 deletions src/openai/resources/audio/transcriptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@

from __future__ import annotations

from typing import List, Union, Mapping, cast
from typing_extensions import Literal
import logging
from typing import TYPE_CHECKING, List, Union, Mapping, cast
from typing_extensions import Literal, overload, assert_never

import httpx

Expand All @@ -24,9 +25,12 @@
from ...types.audio_model import AudioModel
from ...types.audio.transcription import Transcription
from ...types.audio_response_format import AudioResponseFormat
from ...types.audio.transcription_verbose import TranscriptionVerbose

__all__ = ["Transcriptions", "AsyncTranscriptions"]

log: logging.Logger = logging.getLogger("openai.audio.transcriptions")


class Transcriptions(SyncAPIResource):
@cached_property
Expand All @@ -48,14 +52,53 @@ def with_streaming_response(self) -> TranscriptionsWithStreamingResponse:
"""
return TranscriptionsWithStreamingResponse(self)

@overload
def create(
self,
*,
file: FileTypes,
model: Union[str, AudioModel],
response_format: Union[Literal["json"], NotGiven] = NOT_GIVEN,
language: str | NotGiven = NOT_GIVEN,
prompt: str | NotGiven = NOT_GIVEN,
temperature: float | NotGiven = NOT_GIVEN,
timestamp_granularities: List[Literal["word", "segment"]] | NotGiven = NOT_GIVEN,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
extra_query: Query | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> Transcription: ...

@overload
def create(
self,
*,
file: FileTypes,
model: Union[str, AudioModel],
response_format: Literal["verbose_json"],
language: str | NotGiven = NOT_GIVEN,
prompt: str | NotGiven = NOT_GIVEN,
temperature: float | NotGiven = NOT_GIVEN,
timestamp_granularities: List[Literal["word", "segment"]] | NotGiven = NOT_GIVEN,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
extra_query: Query | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> TranscriptionVerbose: ...

@overload
def create(
self,
*,
file: FileTypes,
model: Union[str, AudioModel],
response_format: Literal["text", "srt", "vtt"],
language: str | NotGiven = NOT_GIVEN,
prompt: str | NotGiven = NOT_GIVEN,
response_format: AudioResponseFormat | NotGiven = NOT_GIVEN,
temperature: float | NotGiven = NOT_GIVEN,
timestamp_granularities: List[Literal["word", "segment"]] | NotGiven = NOT_GIVEN,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
Expand All @@ -64,7 +107,25 @@ def create(
extra_query: Query | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> Transcription:
) -> str: ...

def create(
self,
*,
file: FileTypes,
model: Union[str, AudioModel],
language: str | NotGiven = NOT_GIVEN,
prompt: str | NotGiven = NOT_GIVEN,
response_format: Union[AudioResponseFormat, NotGiven] = NOT_GIVEN,
temperature: float | NotGiven = NOT_GIVEN,
timestamp_granularities: List[Literal["word", "segment"]] | NotGiven = NOT_GIVEN,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
extra_query: Query | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> Transcription | TranscriptionVerbose | str:
"""
Transcribes audio into the input language.
Expand Down Expand Up @@ -124,14 +185,14 @@ def create(
# sent to the server will contain a `boundary` parameter, e.g.
# multipart/form-data; boundary=---abc--
extra_headers = {"Content-Type": "multipart/form-data", **(extra_headers or {})}
return self._post(
return self._post( # type: ignore[return-value]
"/audio/transcriptions",
body=maybe_transform(body, transcription_create_params.TranscriptionCreateParams),
files=files,
options=make_request_options(
extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
),
cast_to=Transcription,
cast_to=_get_response_format_type(response_format),
)


Expand All @@ -155,14 +216,15 @@ def with_streaming_response(self) -> AsyncTranscriptionsWithStreamingResponse:
"""
return AsyncTranscriptionsWithStreamingResponse(self)

@overload
async def create(
self,
*,
file: FileTypes,
model: Union[str, AudioModel],
response_format: Union[Literal["json"], NotGiven] = NOT_GIVEN,
language: str | NotGiven = NOT_GIVEN,
prompt: str | NotGiven = NOT_GIVEN,
response_format: AudioResponseFormat | NotGiven = NOT_GIVEN,
temperature: float | NotGiven = NOT_GIVEN,
timestamp_granularities: List[Literal["word", "segment"]] | NotGiven = NOT_GIVEN,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
Expand All @@ -171,7 +233,63 @@ async def create(
extra_query: Query | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> Transcription:
) -> Transcription: ...

@overload
async def create(
self,
*,
file: FileTypes,
model: Union[str, AudioModel],
response_format: Literal["verbose_json"],
language: str | NotGiven = NOT_GIVEN,
prompt: str | NotGiven = NOT_GIVEN,
temperature: float | NotGiven = NOT_GIVEN,
timestamp_granularities: List[Literal["word", "segment"]] | NotGiven = NOT_GIVEN,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
extra_query: Query | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> TranscriptionVerbose: ...

@overload
async def create(
self,
*,
file: FileTypes,
model: Union[str, AudioModel],
response_format: Literal["text", "srt", "vtt"],
language: str | NotGiven = NOT_GIVEN,
prompt: str | NotGiven = NOT_GIVEN,
temperature: float | NotGiven = NOT_GIVEN,
timestamp_granularities: List[Literal["word", "segment"]] | NotGiven = NOT_GIVEN,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
extra_query: Query | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> str: ...

async def create(
self,
*,
file: FileTypes,
model: Union[str, AudioModel],
language: str | NotGiven = NOT_GIVEN,
prompt: str | NotGiven = NOT_GIVEN,
response_format: Union[AudioResponseFormat, NotGiven] = NOT_GIVEN,
temperature: float | NotGiven = NOT_GIVEN,
timestamp_granularities: List[Literal["word", "segment"]] | NotGiven = NOT_GIVEN,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
extra_query: Query | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> Transcription | TranscriptionVerbose | str:
"""
Transcribes audio into the input language.
Expand Down Expand Up @@ -238,7 +356,7 @@ async def create(
options=make_request_options(
extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
),
cast_to=Transcription,
cast_to=_get_response_format_type(response_format),
)


Expand Down Expand Up @@ -276,3 +394,22 @@ def __init__(self, transcriptions: AsyncTranscriptions) -> None:
self.create = async_to_streamed_response_wrapper(
transcriptions.create,
)


def _get_response_format_type(
response_format: Literal["json", "text", "srt", "verbose_json", "vtt"] | NotGiven,
) -> type[Transcription | TranscriptionVerbose | str]:
if isinstance(response_format, NotGiven) or response_format is None: # pyright: ignore[reportUnnecessaryComparison]
return Transcription

if response_format == "json":
return Transcription
elif response_format == "verbose_json":
return TranscriptionVerbose
elif response_format == "srt" or response_format == "text" or response_format == "vtt":
return str
elif TYPE_CHECKING: # type: ignore[unreachable]
assert_never(response_format)
else:
log.warn("Unexpected audio response format: %s", response_format)
return Transcription
Loading

0 comments on commit 76c1f3f

Please sign in to comment.