Fix beam search with batch processing in Whisper decoding
* Ensures that audio features are correctly duplicated across beams for each batch item (see the sketch below).
* Added a test for `decode()` that includes a regression test for this fix.
* Updated *.github/workflows/test.yml* to run the new `decode()` test on the tiny models.
* This issue was introduced in PR openai#1483.
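
For context, a minimal sketch of the shape contract behind the fix (the tensor sizes below are illustrative, not taken from the commit): with `batch_size` audio inputs and `beam_size` beams, the decoder expands the token tensor to `batch_size * beam_size` rows, so the audio features must be expanded the same way or the two tensors no longer line up in the batch dimension.

```python
import torch

batch_size, beam_size = 2, 2
n_ctx, n_state = 1500, 384  # illustrative encoder output sizes

audio_features = torch.randn(batch_size, n_ctx, n_state)
tokens = torch.zeros(batch_size, 1, dtype=torch.long)

# Beam search expands tokens to one row per (batch item, beam).
tokens = tokens.repeat_interleave(beam_size, dim=0)  # (batch_size * beam_size, 1)

# Without the fix, audio_features kept only batch_size rows, so the decoder's
# cross-attention saw mismatched batch dimensions. Repeating it the same way
# keeps every beam paired with its own batch item's features.
audio_features = audio_features.repeat_interleave(beam_size, dim=0)

assert audio_features.shape[0] == tokens.shape[0]  # both are batch_size * beam_size
```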
zuazo committed Jun 1, 2024
1 parent e58f288 commit 537f703
Showing 3 changed files with 56 additions and 1 deletion.
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
@@ -53,4 +53,4 @@ jobs:
       - uses: actions/checkout@v3
       - run: echo "$CONDA/envs/test/bin" >> $GITHUB_PATH
       - run: pip install .["dev"]
-      - run: pytest --durations=0 -vv -k 'not test_transcribe or test_transcribe[tiny] or test_transcribe[tiny.en]' -m 'not requires_cuda'
+      - run: pytest --durations=0 -vv -k 'not (test_transcribe or test_decode) or test_transcribe[tiny] or test_transcribe[tiny.en] or test_decode[tiny] or test_decode[tiny.en]' -m 'not requires_cuda'
52 changes: 52 additions & 0 deletions tests/test_decode.py
@@ -0,0 +1,52 @@
import os

import pytest
import torch

import whisper


@pytest.mark.parametrize("model_name", whisper.available_models())
def test_decode(model_name: str):
    # Regression test: batch_size and beam_size should work together
    beam_size = 2
    batch_size = 2

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = whisper.load_model(model_name).to(device)
    audio_path = os.path.join(os.path.dirname(__file__), "jfk.flac")

    language = "en" if model_name.endswith(".en") else None

    options = whisper.DecodingOptions(language=language, beam_size=beam_size)

    audio = whisper.load_audio(audio_path)
    audio = whisper.pad_or_trim(audio)
    mel = whisper.log_mel_spectrogram(audio).to(device)

    # Create a small batch
    batch_mel = mel.unsqueeze(0).repeat(batch_size, 1, 1)

    results = model.decode(batch_mel, options)

    # Since both examples are the same, results should be identical
    assert len(results) == batch_size
    assert results[0].text == results[1].text

    decoded_text = results[0].text.lower()
    assert "my fellow americans" in decoded_text
    assert "your country" in decoded_text
    assert "do for you" in decoded_text

    timing_checked = False
    if hasattr(results[0], "segments"):
        for segment in results[0].segments:
            for timing in segment["words"]:
                assert timing["start"] < timing["end"]
                if timing["word"].strip(" ,") == "Americans":
                    assert timing["start"] <= 1.8
                    assert timing["end"] >= 1.8
                    timing_checked = True

    if hasattr(results[0], "segments"):
        assert timing_checked
3 changes: 3 additions & 0 deletions whisper/decoding.py
@@ -731,6 +731,9 @@ def run(self, mel: Tensor) -> List[DecodingResult]:
             ]
 
         # repeat text tensors by the group size, for beam search or best-of-n sampling
+        audio_features = audio_features.repeat_interleave(self.n_group, dim=0).to(
+            audio_features.device
+        )
         tokens = tokens.repeat_interleave(self.n_group, dim=0).to(audio_features.device)
 
         # call the main sampling loop
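As a side note (not part of the diff): `repeat_interleave` rather than `Tensor.repeat` is what keeps all beams for a given batch item contiguous, matching how `tokens` is expanded in the line above and how results are later regrouped per audio input. A tiny illustration of the ordering difference, with made-up values:

```python
import torch

features = torch.tensor([[1.0], [2.0]])  # two batch items, toy "features"

interleaved = features.repeat_interleave(2, dim=0)  # beams stay grouped per item
tiled = features.repeat(2, 1)                       # items alternate instead

print(interleaved.squeeze().tolist())  # [1.0, 1.0, 2.0, 2.0]
print(tiled.squeeze().tolist())        # [1.0, 2.0, 1.0, 2.0]
```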
