Skip to content

Commit

Permalink
Merge pull request #96 from smly/fix-batch-processing
Browse files Browse the repository at this point in the history
FIX: Assertion error in batch processing
  • Loading branch information
m-bain authored Feb 22, 2023
2 parents 2b1ffa1 + 27fe502 commit 847a3cd
Showing 1 changed file with 21 additions and 26 deletions.
47 changes: 21 additions & 26 deletions whisperx/transcribe.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,8 +365,6 @@ def transcribe_with_vad_parallel(
if mel is None:
mel = log_mel_spectrogram(audio)

output = {"segments": []}

vad_segments = vad_pipeline(audio)
# merge segments into approx. 30s inputs, the length most appropriate for whisper
vad_segments = merge_chunks(vad_segments)
Expand Down Expand Up @@ -428,7 +426,9 @@ def transcribe_with_vad_parallel(
language = kwargs["language"]
task = kwargs["task"]
tokenizer = get_tokenizer(model.is_multilingual, language=language, task=task)
result_segments = post_process_results(

output = post_process_results(
vad_segments,
decode_result,
duration_list,
offset_list,
Expand All @@ -438,29 +438,11 @@ def transcribe_with_vad_parallel(
no_speech_threshold=no_speech_threshold,
logprob_threshold=logprob_threshold,
verbose=verbose)

# post processing: collect outputs
assert len(result_segments) == len(vad_segments)
for sdx, (seg_t, result) in enumerate(zip(vad_segments, result_segments)):
seg_t["text"] = result["text"]
output["segments"].append(
{
"start": seg_t["start"],
"end": seg_t["end"],
"language": result["language"],
"text": result["text"],
"seg-text": [x["text"] for x in result["segments"]],
"seg-start": [x["start"] for x in result["segments"]],
"seg-end": [x["end"] for x in result["segments"]],
}
)

output["language"] = output["segments"][0]["language"]

return output


def post_process_results(
vad_segments,
result_list,
duration_list,
offset_list,
Expand All @@ -478,7 +460,7 @@ def post_process_results(
) # time per output token: 0.02 (seconds)
all_tokens = []
all_segments = []
outputs = []
output = {"segments": []}

def add_segment(
*, start: float, end: float, text_tokens: torch.Tensor, result: DecodingResult
Expand All @@ -505,7 +487,7 @@ def add_segment(
print(f"[{format_timestamp(start)} --> {format_timestamp(end)}] {text}")

# process the output
for result, segment_duration, timestamp_offset in zip(result_list, duration_list, offset_list):
for seg_t, result, segment_duration, timestamp_offset in zip(vad_segments, result_list, duration_list, offset_list):
all_tokens = []
all_segments = []

Expand Down Expand Up @@ -568,9 +550,22 @@ def add_segment(
seek += segment_shape
all_tokens.extend(tokens.tolist())

outputs.append(dict(text=tokenizer.decode(all_tokens), segments=all_segments, language=language))
result = dict(text=tokenizer.decode(all_tokens), segments=all_segments, language=language)
output["segments"].append(
{
"start": seg_t["start"],
"end": seg_t["end"],
"language": result["language"],
"text": result["text"],
"seg-text": [x["text"] for x in result["segments"]],
"seg-start": [x["start"] for x in result["segments"]],
"seg-end": [x["end"] for x in result["segments"]],
}
)

output["language"] = output["segments"][0]["language"]

return outputs
return output


def cli():
Expand Down

0 comments on commit 847a3cd

Please sign in to comment.