Skip to content

Commit

Permalink
fix(cli/audio): handle non-json response format (#1557)
Browse files Browse the repository at this point in the history
* Fix handling of --response-format in audio transcriptions create command

* handle the string case in audio directly

---------

Co-authored-by: Robert Craigie <robert@craigie.dev>
  • Loading branch information
aurishhammadhafeez and RobertCraigie committed Jul 22, 2024
1 parent af8f606 commit bb7431f
Showing 1 changed file with 33 additions and 19 deletions.
52 changes: 33 additions & 19 deletions src/openai/cli/_api/audio.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
from __future__ import annotations

import sys
from typing import TYPE_CHECKING, Any, Optional, cast
from argparse import ArgumentParser

from .._utils import get_client, print_model
from ..._types import NOT_GIVEN
from .._models import BaseModel
from .._progress import BufferReader
from ...types.audio import Transcription

if TYPE_CHECKING:
from argparse import _SubParsersAction
Expand Down Expand Up @@ -65,30 +67,42 @@ def transcribe(args: CLITranscribeArgs) -> None:
with open(args.file, "rb") as file_reader:
buffer_reader = BufferReader(file_reader.read(), desc="Upload progress")

model = get_client().audio.transcriptions.create(
file=(args.file, buffer_reader),
model=args.model,
language=args.language or NOT_GIVEN,
temperature=args.temperature or NOT_GIVEN,
prompt=args.prompt or NOT_GIVEN,
# casts required because the API is typed for enums
# but we don't want to validate that here for forwards-compat
response_format=cast(Any, args.response_format),
model = cast(
"Transcription | str",
get_client().audio.transcriptions.create(
file=(args.file, buffer_reader),
model=args.model,
language=args.language or NOT_GIVEN,
temperature=args.temperature or NOT_GIVEN,
prompt=args.prompt or NOT_GIVEN,
# casts required because the API is typed for enums
# but we don't want to validate that here for forwards-compat
response_format=cast(Any, args.response_format),
),
)
print_model(model)
if isinstance(model, str):
sys.stdout.write(model + "\n")
else:
print_model(model)

@staticmethod
def translate(args: CLITranslationArgs) -> None:
with open(args.file, "rb") as file_reader:
buffer_reader = BufferReader(file_reader.read(), desc="Upload progress")

model = get_client().audio.translations.create(
file=(args.file, buffer_reader),
model=args.model,
temperature=args.temperature or NOT_GIVEN,
prompt=args.prompt or NOT_GIVEN,
# casts required because the API is typed for enums
# but we don't want to validate that here for forwards-compat
response_format=cast(Any, args.response_format),
model = cast(
"Transcription | str",
get_client().audio.translations.create(
file=(args.file, buffer_reader),
model=args.model,
temperature=args.temperature or NOT_GIVEN,
prompt=args.prompt or NOT_GIVEN,
# casts required because the API is typed for enums
# but we don't want to validate that here for forwards-compat
response_format=cast(Any, args.response_format),
),
)
print_model(model)
if isinstance(model, str):
sys.stdout.write(model + "\n")
else:
print_model(model)

0 comments on commit bb7431f

Please sign in to comment.