From 9d9cb23d41875e66b31ee759c368473a6e1a761e Mon Sep 17 00:00:00 2001 From: jiltseb Date: Mon, 2 Sep 2024 15:49:09 +0000 Subject: [PATCH 01/18] updates in asr post processing for diarized transcription --- aana/core/models/asr.py | 6 + aana/core/models/whisper.py | 2 +- aana/processors/speaker.py | 370 ++++++++++++++++++++++++++++++++++++ 3 files changed, 377 insertions(+), 1 deletion(-) create mode 100644 aana/processors/speaker.py diff --git a/aana/core/models/asr.py b/aana/core/models/asr.py index 8754a162..0190f37f 100644 --- a/aana/core/models/asr.py +++ b/aana/core/models/asr.py @@ -31,11 +31,13 @@ class AsrWord(BaseModel): Attributes: word (str): The word text. + speaker (str| None): Speaker label for the word. time_interval (TimeInterval): Time interval of the word. alignment_confidence (float): Alignment confidence of the word, >= 0.0 and <= 1.0. """ word: str = Field(description="The word text") + speaker: str | None = Field(description="Speaker label for the word") time_interval: TimeInterval = Field(description="Time interval of the word") alignment_confidence: float = Field( ge=0.0, le=1.0, description="Alignment confidence of the word" @@ -52,6 +54,7 @@ def from_whisper(cls, whisper_word: WhisperWord) -> "AsrWord": AsrWord: The converted AsrWord. """ return cls( + speaker=None, word=whisper_word.word, time_interval=TimeInterval(start=whisper_word.start, end=whisper_word.end), alignment_confidence=whisper_word.probability, @@ -73,6 +76,7 @@ class AsrSegment(BaseModel): confidence (float | None): Confidence of the segment. no_speech_confidence (float | None): Chance of being a silence segment. words (list[AsrWord]): List of words in the segment. Default is []. + speaker (str | None): Speaker label. Default is None. """ text: str = Field(description="The text of the segment (transcript/translation)") @@ -86,6 +90,7 @@ class AsrSegment(BaseModel): words: list[AsrWord] = Field( description="List of words in the segment", default_factory=list ) + speaker: str | None = Field(None, description="speaker label of the segment") @classmethod def from_whisper(cls, whisper_segment: WhisperSegment) -> "AsrSegment": @@ -116,6 +121,7 @@ def from_whisper(cls, whisper_segment: WhisperSegment) -> "AsrSegment": confidence=confidence, no_speech_confidence=no_speech_confidence, words=words, + speaker=None, ) model_config = ConfigDict( diff --git a/aana/core/models/whisper.py b/aana/core/models/whisper.py index 54c1a615..cda91414 100644 --- a/aana/core/models/whisper.py +++ b/aana/core/models/whisper.py @@ -44,7 +44,7 @@ class WhisperParams(BaseModel): ), ) word_timestamps: bool = Field( - default=False, description="Whether to extract word-level timestamps." + default=True, description="Whether to extract word-level timestamps." ) vad_filter: bool = Field( default=True, diff --git a/aana/processors/speaker.py b/aana/processors/speaker.py new file mode 100644 index 00000000..34727c68 --- /dev/null +++ b/aana/processors/speaker.py @@ -0,0 +1,370 @@ +from collections import defaultdict + +from aana.core.models.asr import AsrSegment, AsrWord +from aana.core.models.speaker import SpeakerDiarizationSegment +from aana.core.models.time import TimeInterval +from aana.deployments.pyannote_speaker_diarization_deployment import ( + SpeakerDiarizationOutput, +) +from aana.deployments.whisper_deployment import WhisperOutput + +# Utility functions for speaker-related processing + +# Define sentence ending punctuations: +sentence_ending_punctuations = ".?!" 
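+
+# Overview of the post-processing flow implemented in this module (a sketch of
+# how the helpers below are combined; see asr_postprocessing_for_diarization):
+#   1. assign_word_speakers      - label every ASR segment and word with the
+#                                  diarization speaker that overlaps it most.
+#   2. align_with_punctuation    - move speaker switches that fall mid-sentence
+#                                  to sentence boundaries via a majority vote.
+#   3. create_speaker_segments   - regroup the labelled words into per-speaker
+#                                  AsrSegment objects (optionally merged).
+#
+# Minimal usage sketch (assumes word-level timestamps were requested from the
+# ASR model, since speakers are assigned per word):
+#
+#   transcription = asr_postprocessing_for_diarization(diarized_output, transcription)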
+ + +# AsrSegment and AsrWord has a speaker label that defaults to None. +def assign_word_speakers( + diarized_output: SpeakerDiarizationOutput, + transcription: WhisperOutput, + fill_nearest: bool = False, +) -> WhisperOutput: + """Assigns speaker labels to each segment and word in the transcription based on diarized output. + + Parameters: + - diarized_output (SpeakerDiarizationOutput): Contains speaker diarization segments. + - transcription (WhisperOutput): Transcription data with segments, text, and language_info. + - fill_nearest (bool): If True, assigns the closest speaker even if there's no positive overlap. Default is False. + + Returns: + - transcription (WhisperOutput): Transcription updated in-place with the assigned speaker labels. + """ + for segment in transcription["segments"]: + # Assign speaker to segment + segment.speaker = get_speaker_for_interval( + diarized_output["segments"], + segment.time_interval.start, + segment.time_interval.end, + fill_nearest, + ) + + # Assign speakers to words within the segment + if segment.words: + for word in segment.words: + word.speaker = get_speaker_for_interval( + diarized_output["segments"], + word.time_interval.start, + word.time_interval.end, + fill_nearest, + ) + + return transcription + + +def get_speaker_for_interval( + sd_segments: list[SpeakerDiarizationSegment], + start_time: float, + end_time: float, + fill_nearest: bool, +) -> str | None: + """Determines the speaker for a given time interval based on diarized segments. + + Parameters: + - sd_segments (list[SpeakerDiarizationSegment]): List of speaker diarization segments. + - start_time (float): Start time of the interval. + - end_time (float): End time of the interval. + - fill_nearest (bool): If True, selects the closest speaker even with no overlap. + + Returns: + - str | None: The identified speaker label, or None if no speaker is found. 
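+
+    Note:
+    - With fill_nearest=False, the speaker whose diarization segment has the
+      largest overlap with [start_time, end_time] is returned, with ties broken
+      by the smaller time gap to the interval.
+    - With fill_nearest=True, every diarization segment is considered and the
+      closest one in time wins, so a speaker is returned even without overlap.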
+ """ + overlaps = [] + + for sd_segment in sd_segments: + interval_start = sd_segment.time_interval.start + interval_end = sd_segment.time_interval.end + + # Calculate overlap duration + overlap_start = max(start_time, interval_start) + overlap_end = min(end_time, interval_end) + overlap_duration = max(0.0, overlap_end - overlap_start) + + if overlap_duration > 0 or fill_nearest: + # Calculate union duration for potential future use + # union_duration = max(end_time, interval_end) - min(start_time, interval_start) + distance = float( + min(abs(start_time - interval_end), abs(end_time - interval_start)) + ) + + overlaps.append( + { + "speaker": sd_segment.speaker, + "overlap_duration": overlap_duration, + "distance": distance, + } + ) + + if not overlaps: + return None + else: + # Select the speaker with the maximum overlap duration (or minimal distance) + best_match = max( + overlaps, + key=lambda x: (x["overlap_duration"], -x["distance"]) + if not fill_nearest + else (-x["distance"]), + ) + return best_match["speaker"] + + +def get_first_word_idx_of_sentence( + word_idx: int, + word_list: list[str], + speaker_list: list[str | None], + max_words: int, +) -> int: + """Get the index of the first word of the sentence in the given range.""" + left_idx = word_idx + while ( + left_idx > 0 + and word_idx - left_idx < max_words + and speaker_list[left_idx - 1] == speaker_list[left_idx] + and word_list[left_idx - 1][-1] not in sentence_ending_punctuations + ): + left_idx -= 1 + + return ( + left_idx + if left_idx == 0 or word_list[left_idx - 1][-1] in sentence_ending_punctuations + else -1 + ) + + +def get_last_word_idx_of_sentence( + word_idx: int, word_list: list[str], max_words: int +) -> int: + """Get the index of the last word of the sentence in the given range.""" + right_idx = word_idx + while ( + right_idx < len(word_list) + and right_idx - word_idx < max_words + and word_list[right_idx][-1] not in sentence_ending_punctuations + ): + right_idx += 1 + + return ( + right_idx + if right_idx == len(word_list) - 1 + or word_list[right_idx][-1] in sentence_ending_punctuations + else -1 + ) + + +def find_nearest_speaker( + index: int, word_speaker_mapping: list[AsrWord], reverse: bool = False +) -> str | None: + """Find the nearest speaker label in the word_speaker_mapping either forward or backward. + + Parameters: + - index (int): The index to start searching from. + - word_speaker_mapping (list[AsrWord]): List of word-speaker mappings. + - reverse (bool): Search backwards if True; forwards if False. Default is False. + + Returns: + - str | None: The nearest speaker found or None if not found. + """ + step = -1 if reverse else 1 + for i in range(index, len(word_speaker_mapping) if not reverse else -1, step): + if word_speaker_mapping[i].speaker: + return word_speaker_mapping[i].speaker + return None + + +def align_with_punctuation( + transcription: WhisperOutput, max_words_in_sentence: int = 50 +) -> list[AsrWord]: + """Aligns speaker labels with sentence boundaries defined by punctuation. + + Parameters: + - transcription (WhisperOutput): transcription with speaker information. + - max_words_in_sentence (int): Maximum number of words allowed in a sentence. + + Returns: + - list[AsrWord]: Realigned word-speaker mappings. 
+ """ + new_segments = [segment.words for segment in transcription["segments"]] + word_speaker_mapping = [word for segment in new_segments for word in segment] + words_list = [item.word for item in word_speaker_mapping] + speaker_list = [item.speaker for item in word_speaker_mapping] + + # Fill missing speaker labels by finding the nearest speaker + for i, item in enumerate(word_speaker_mapping): + if item.speaker is None: + item.speaker = find_nearest_speaker(i, word_speaker_mapping, reverse=i > 0) + speaker_list[i] = item.speaker + + # Align speakers with sentence boundaries + k = 0 + while k < len(word_speaker_mapping): + if ( + k < len(word_speaker_mapping) - 1 + and speaker_list[k] != speaker_list[k + 1] + and words_list[k][-1] not in sentence_ending_punctuations + ): + left_idx = get_first_word_idx_of_sentence( + k, words_list, speaker_list, max_words_in_sentence + ) + right_idx = get_last_word_idx_of_sentence( + k, words_list, max_words_in_sentence - (k - left_idx) + ) + + if min(left_idx, right_idx) == -1: + k += 1 + continue + + spk_labels = speaker_list[left_idx : right_idx + 1] + mod_speaker = max(set(spk_labels), key=spk_labels.count) + + if spk_labels.count(mod_speaker) >= len(spk_labels) // 2: + speaker_list[left_idx : right_idx + 1] = [mod_speaker] * ( + right_idx - left_idx + 1 + ) + k = right_idx + + k += 1 + + # Realign the speaker labels in the original word_speaker_mapping + for i, item in enumerate(word_speaker_mapping): + item.speaker = speaker_list[i] + + return word_speaker_mapping + + +def create_speaker_segments( + word_list: list[AsrWord], merge: bool = False +) -> list[AsrSegment]: + """Creates speaker segments from a list of words with speaker annotations and timing information. + + Parameters: + - word_list (List[AsrWord]): A list of words with associated speaker and timing details. + - merge (bool): If True, merges consecutive segments from the same speaker into one segment. + + Returns: + - List[AsrSegment]: A list of segments where each segment groups words spoken by the same speaker. + """ + if not word_list: + return [] + + # Initialize variables to track current speaker and segment + current_speaker = None + current_segment = None + final_segments: list[AsrSegment] = [] + + for word_info in word_list: + # Default speaker assignment to the last known speaker if missing + speaker = word_info.speaker or current_speaker + + # Check for speaker change or start of a new segment + if speaker != current_speaker: + if current_segment: + final_segments.append(current_segment) + + # Start a new segment + # TODO: Also get the original confidence measurements from segments and update it. 
+ current_segment = AsrSegment( + time_interval=TimeInterval( + start=word_info.time_interval.start, end=word_info.time_interval.end + ), + text=word_info.word, + speaker=speaker, + words=[word_info], + confidence=None, + no_speech_confidence=None, + ) + current_speaker = speaker + else: + if current_segment: + # Update the current segment for continuous speaker + current_segment.time_interval.end = word_info.time_interval.end + current_segment.text += f" {word_info.word}" # Add space between words + current_segment.words.append(word_info) + + # Append the last segment if it exists + if current_segment: + final_segments.append(current_segment) + + # Optional merging of consecutive segments by the same speaker + if merge: + final_segments = merge_consecutive_speaker_segments(final_segments) + + return final_segments + + +def merge_consecutive_speaker_segments( + segments: list[AsrSegment], +) -> list[AsrSegment]: + """Merges consecutive segments that have the same speaker into a single segment. + + Parameters: + - segments (List[AsrSegment]): The initial list of segments. + + Returns: + - List[AsrSegment]: A new list of merged segments. + """ + if not segments: + return [] + + merged_segments: list[AsrSegment] = [] + mapping: defaultdict[str, str] = defaultdict(str) + + current_segment = segments[0] + speaker_counter = 0 + + for next_segment in segments[1:]: + if next_segment.speaker == current_segment.speaker: + # Merge segments + current_segment.time_interval.end = next_segment.time_interval.end + current_segment.text += f" {next_segment.text}" + current_segment.words.extend(next_segment.words) + else: + # Assign unique speaker labels and finalize the current segment + if current_segment.speaker: + current_segment.speaker = mapping.setdefault( + current_segment.speaker, f"SPEAKER_{speaker_counter:02d}" + ) + for word in current_segment.words: + word.speaker = current_segment.speaker + merged_segments.append(current_segment) + current_segment = next_segment + speaker_counter += 1 + + # Handle the last segment + if current_segment.speaker: + current_segment.speaker = mapping.setdefault( + current_segment.speaker, f"SPEAKER_{speaker_counter:02d}" + ) + for word in current_segment.words: + word.speaker = current_segment.speaker + merged_segments.append(current_segment) + + return merged_segments + + +# Full method +def asr_postprocessing_for_diarization( + diarized_output: SpeakerDiarizationOutput, transcription: WhisperOutput +) -> WhisperOutput: + """Perform diarized transcription based on individual deployments. + + Parameters: + - diarized_output (SpeakerDiarizationOutput): Contains speaker diarization segments. + - transcription (WhisperOutput): Transcription data with segments, text, and language_info. + + Returns: + - transcription (WhisperOutput): Updated transcription with diarized information per segment/word. + + """ + # 1. Assign speaker labels to each segment and each word in WhisperOutput based on SpeakerDiarizationOutput. + + speaker_labelled_transcription = assign_word_speakers( + diarized_output, transcription + ) + # 2. Aligns the speakers with the punctuations: + word_speaker_mapping = align_with_punctuation(speaker_labelled_transcription) + + # 3. Create ASR segments by combining the AsrWord with speaker information + # and optionally combine segments based on speaker info. 
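+    # (create_speaker_segments is called with its default merge=False; passing
+    # merge=True would additionally collapse consecutive same-speaker segments
+    # and re-label speakers via merge_consecutive_speaker_segments.)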
+ updated_transcription = create_speaker_segments(word_speaker_mapping) + transcription["segments"] = updated_transcription + return transcription From 5b3ab99d002562560d19e7d3add35daf581fc742 Mon Sep 17 00:00:00 2001 From: jiltseb Date: Sun, 8 Sep 2024 20:35:57 +0000 Subject: [PATCH 02/18] diarized_transcription; initial commit. Post processin g scripts, example notebook, unit test, test files and modifies sd deployment --- aana/core/models/whisper.py | 2 +- ...pyannote_speaker_diarization_deployment.py | 26 +- aana/processors/speaker.py | 107 ++- ...pyannote_speaker_diarization_deployment.py | 8 +- aana/tests/files/expected/sd/sd_sample.json | 75 +- .../whisper/whisper_medium_sd_sample.wav.json | 1 + .../whisper_medium_sd_sample.wav_diar.json | 827 ++++++++++++++++++ aana/tests/units/test_speaker.py | 67 ++ .../diarized_transcription_example.ipynb | 292 +++++++ 9 files changed, 1389 insertions(+), 16 deletions(-) create mode 100644 aana/tests/files/expected/whisper/whisper_medium_sd_sample.wav.json create mode 100644 aana/tests/files/expected/whisper/whisper_medium_sd_sample.wav_diar.json create mode 100644 aana/tests/units/test_speaker.py create mode 100644 notebooks/diarized_transcription_example.ipynb diff --git a/aana/core/models/whisper.py b/aana/core/models/whisper.py index cda91414..54c1a615 100644 --- a/aana/core/models/whisper.py +++ b/aana/core/models/whisper.py @@ -44,7 +44,7 @@ class WhisperParams(BaseModel): ), ) word_timestamps: bool = Field( - default=True, description="Whether to extract word-level timestamps." + default=False, description="Whether to extract word-level timestamps." ) vad_filter: bool = Field( default=True, diff --git a/aana/deployments/pyannote_speaker_diarization_deployment.py b/aana/deployments/pyannote_speaker_diarization_deployment.py index ee0e1b90..0fbdc9e7 100644 --- a/aana/deployments/pyannote_speaker_diarization_deployment.py +++ b/aana/deployments/pyannote_speaker_diarization_deployment.py @@ -1,3 +1,4 @@ +import logging from typing import Any, TypedDict import torch @@ -15,6 +16,7 @@ from aana.core.models.time import TimeInterval from aana.deployments.base_deployment import BaseDeployment from aana.exceptions.runtime import InferenceException +from aana.processors.speaker import combine_homogeneous_speaker_segs class SpeakerDiarizationOutput(TypedDict): @@ -69,9 +71,21 @@ async def apply_config(self, config: dict[str, Any]): if torch.cuda.is_available(): torch.cuda.manual_seed_all(42) - # load model using pyannote Pipeline - self.diarize_model = Pipeline.from_pretrained(self.model_id) - self.diarize_model.to(torch.device(self.device)) + try: + # load model using pyannote Pipeline + self.diarize_model = Pipeline.from_pretrained(self.model_id) + + # Check if model, log error if None. + if self.diarize_model is None: + logging.error( + f"Accept user agreements at huggingface for the model: {self.model_id}" + ) + + else: + self.diarize_model.to(torch.device(self.device)) + + except Exception as e: + raise InferenceException(self.model_id) from e async def __inference( self, audio: Audio, params: PyannoteSpeakerDiarizationParams @@ -134,4 +148,8 @@ async def diarize( ) ) - return SpeakerDiarizationOutput(segments=speaker_diarization_segments) + # Perform post-processing to combine homogeneous speaker segments. 
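+        # pyannote occasionally emits overlapping or back-to-back chunks for the
+        # same speaker; combine_homogeneous_speaker_segs merges such runs and
+        # trims any overlap with the previously emitted segment.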
+ processed_speaker_diarization_segments = combine_homogeneous_speaker_segs( + SpeakerDiarizationOutput(segments=speaker_diarization_segments) + ) + return processed_speaker_diarization_segments diff --git a/aana/processors/speaker.py b/aana/processors/speaker.py index 34727c68..0df24b26 100644 --- a/aana/processors/speaker.py +++ b/aana/processors/speaker.py @@ -1,15 +1,43 @@ from collections import defaultdict +from typing import TypedDict -from aana.core.models.asr import AsrSegment, AsrWord +from aana.core.models.asr import ( + AsrSegment, + AsrTranscription, + AsrTranscriptionInfo, + AsrWord, +) from aana.core.models.speaker import SpeakerDiarizationSegment from aana.core.models.time import TimeInterval -from aana.deployments.pyannote_speaker_diarization_deployment import ( - SpeakerDiarizationOutput, -) -from aana.deployments.whisper_deployment import WhisperOutput # Utility functions for speaker-related processing + +# redefine SpeakerDiarizationOutput and WhisperOutput to prevent circular imports +class SpeakerDiarizationOutput(TypedDict): + """The output of the Speaker Diarization model. + + Attributes: + segments (list[SpeakerDiarizationSegment]): The Speaker Diarization segments. + """ + + segments: list[SpeakerDiarizationSegment] + + +class WhisperOutput(TypedDict): + """The output of the whisper model. + + Attributes: + segments (list[AsrSegment]): The ASR segments. + transcription_info (AsrTranscriptionInfo): The ASR transcription info. + transcription (AsrTranscription): The ASR transcription. + """ + + segments: list[AsrSegment] + transcription_info: AsrTranscriptionInfo + transcription: AsrTranscription + + # Define sentence ending punctuations: sentence_ending_punctuations = ".?!" @@ -277,7 +305,7 @@ def create_speaker_segments( if current_segment: # Update the current segment for continuous speaker current_segment.time_interval.end = word_info.time_interval.end - current_segment.text += f" {word_info.word}" # Add space between words + current_segment.text += f"{word_info.word}" # Add space between words current_segment.words.append(word_info) # Append the last segment if it exists @@ -341,7 +369,7 @@ def merge_consecutive_speaker_segments( return merged_segments -# Full method +# Full Method def asr_postprocessing_for_diarization( diarized_output: SpeakerDiarizationOutput, transcription: WhisperOutput ) -> WhisperOutput: @@ -368,3 +396,68 @@ def asr_postprocessing_for_diarization( updated_transcription = create_speaker_segments(word_speaker_mapping) transcription["segments"] = updated_transcription return transcription + + +# speaker diarization model occationally produce overlapping chunks/ same speaker segments, +# below function combines them properly + + +def combine_homogeneous_speaker_segs( + diarized_output: SpeakerDiarizationOutput, +) -> SpeakerDiarizationOutput: + """Combines segments with the same speaker into homogeneous speaker segments, ensuring no overlapping times. + + Parameters: + - diarized_output (SpeakerDiarizationOutput): Input with segments that may have overlapping times. + + Returns: + - SpeakerDiarizationOutput: Output with combined homogeneous speaker segments. 
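+
+    Note:
+    - Segments are processed in order of start time; consecutive segments from
+      the same speaker are merged by extending the end time, and if a segment
+      overlaps the previously emitted one, its start is moved to that segment's
+      end so the returned intervals do not overlap.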
+ """ + combined_segments: list = [] + current_speaker = None + current_segment = None + + for segment in sorted( + diarized_output["segments"], key=lambda x: x.time_interval.start + ): + speaker = segment.speaker + + # If there's a speaker change or current_segment is None, finalize current and start a new one + if current_speaker != speaker: + # Finalize the current segment if it exists + if current_segment: + combined_segments.append(current_segment) + + current_speaker = speaker + + # Start a new segment for the current speaker + current_segment = SpeakerDiarizationSegment( + time_interval=TimeInterval( + start=segment.time_interval.start, end=segment.time_interval.end + ), + speaker=current_speaker, + ) + else: + if current_segment: + # Extend the current segment for the same speaker + # Ensure there is no overlap; take the maximum of current end and incoming end + current_segment.time_interval.end = max( + current_segment.time_interval.end, segment.time_interval.end + ) + + # Adjust the start of the next segment if there's any overlap + if ( + current_segment + and len(combined_segments) > 0 + and combined_segments[-1].time_interval.end + > current_segment.time_interval.start + ): + current_segment.time_interval.start = combined_segments[ + -1 + ].time_interval.end + + # Add the last segment if it exists + if current_segment: + combined_segments.append(current_segment) + + return SpeakerDiarizationOutput(segments=combined_segments) diff --git a/aana/tests/deployments/test_pyannote_speaker_diarization_deployment.py b/aana/tests/deployments/test_pyannote_speaker_diarization_deployment.py index 70eed5d1..97a0391f 100644 --- a/aana/tests/deployments/test_pyannote_speaker_diarization_deployment.py +++ b/aana/tests/deployments/test_pyannote_speaker_diarization_deployment.py @@ -1,7 +1,7 @@ # ruff: noqa: S101 from importlib import resources from pathlib import Path - +import json import pytest from aana.core.models.audio import Audio @@ -21,7 +21,7 @@ max_ongoing_requests=1000, ray_actor_options={"num_gpus": 0.05}, user_config=PyannoteSpeakerDiarizationConfig( - model_name=("pyannote/speaker-diarization-3.1"), + model_id=("pyannote/speaker-diarization-3.1"), sample_rate=16000, ).model_dump(mode="json"), ), @@ -55,5 +55,7 @@ async def test_speaker_diarization(self, setup_deployment, audio_file): output = await handle.diarize(audio=audio) output = pydantic_to_dict(output) - + # save or print this, evenif it is different. 
+ # with open(expected_output_path, "w") as json_file: + # json.dump(output, json_file, indent=4) verify_deployment_results(expected_output_path, output) diff --git a/aana/tests/files/expected/sd/sd_sample.json b/aana/tests/files/expected/sd/sd_sample.json index 367d4f14..0a096725 100644 --- a/aana/tests/files/expected/sd/sd_sample.json +++ b/aana/tests/files/expected/sd/sd_sample.json @@ -1 +1,74 @@ -{"segments": [{"time_interval": {"start": 6.730343750000001, "end": 7.16909375}, "speaker": "SPEAKER_01"}, {"time_interval": {"start": 7.16909375, "end": 7.185968750000001}, "speaker": "SPEAKER_02"}, {"time_interval": {"start": 7.59096875, "end": 8.316593750000003}, "speaker": "SPEAKER_01"}, {"time_interval": {"start": 8.316593750000003, "end": 9.919718750000001}, "speaker": "SPEAKER_02"}, {"time_interval": {"start": 9.919718750000001, "end": 10.93221875}, "speaker": "SPEAKER_01"}, {"time_interval": {"start": 10.45971875, "end": 14.745968750000003}, "speaker": "SPEAKER_02"}, {"time_interval": {"start": 10.93221875, "end": 10.98284375}, "speaker": "SPEAKER_00"}, {"time_interval": {"start": 14.30721875, "end": 17.88471875}, "speaker": "SPEAKER_00"}, {"time_interval": {"start": 18.01971875, "end": 21.512843750000002}, "speaker": "SPEAKER_02"}, {"time_interval": {"start": 18.15471875, "end": 18.44159375}, "speaker": "SPEAKER_00"}, {"time_interval": {"start": 21.765968750000003, "end": 28.49909375}, "speaker": "SPEAKER_00"}, {"time_interval": {"start": 27.85784375, "end": 29.96721875}, "speaker": "SPEAKER_02"}]} \ No newline at end of file +{ + "segments": [ + { + "time_interval": { + "start": 6.730343750000001, + "end": 7.16909375 + }, + "speaker": "SPEAKER_01" + }, + { + "time_interval": { + "start": 7.16909375, + "end": 7.185968750000001 + }, + "speaker": "SPEAKER_02" + }, + { + "time_interval": { + "start": 7.59096875, + "end": 8.316593750000003 + }, + "speaker": "SPEAKER_01" + }, + { + "time_interval": { + "start": 8.316593750000003, + "end": 9.919718750000001 + }, + "speaker": "SPEAKER_02" + }, + { + "time_interval": { + "start": 9.919718750000001, + "end": 10.93221875 + }, + "speaker": "SPEAKER_01" + }, + { + "time_interval": { + "start": 10.93221875, + "end": 14.745968750000003 + }, + "speaker": "SPEAKER_02" + }, + { + "time_interval": { + "start": 14.745968750000003, + "end": 17.88471875 + }, + "speaker": "SPEAKER_00" + }, + { + "time_interval": { + "start": 18.01971875, + "end": 21.512843750000002 + }, + "speaker": "SPEAKER_02" + }, + { + "time_interval": { + "start": 21.512843750000002, + "end": 28.49909375 + }, + "speaker": "SPEAKER_00" + }, + { + "time_interval": { + "start": 28.49909375, + "end": 29.96721875 + }, + "speaker": "SPEAKER_02" + } + ] +} \ No newline at end of file diff --git a/aana/tests/files/expected/whisper/whisper_medium_sd_sample.wav.json b/aana/tests/files/expected/whisper/whisper_medium_sd_sample.wav.json new file mode 100644 index 00000000..18411406 --- /dev/null +++ b/aana/tests/files/expected/whisper/whisper_medium_sd_sample.wav.json @@ -0,0 +1 @@ +{"segments":[{"text":" Hello. Hello. Oh, hello. I didn't know you were there. Neither did I. 
Okay, I thought, you know, I heard it deep.","time_interval":{"start":6.38,"end":12.38},"confidence":0.8329984157521475,"no_speech_confidence":0.012033582665026188,"words":[{"word":" Hello.","speaker":null,"time_interval":{"start":6.38,"end":7.0},"alignment_confidence":0.6853185296058655},{"word":" Hello.","speaker":null,"time_interval":{"start":7.5,"end":7.84},"alignment_confidence":0.7124693989753723},{"word":" Oh,","speaker":null,"time_interval":{"start":8.3,"end":8.48},"alignment_confidence":0.8500092029571533},{"word":" hello.","speaker":null,"time_interval":{"start":8.5,"end":8.76},"alignment_confidence":0.9408962726593018},{"word":" I","speaker":null,"time_interval":{"start":8.92,"end":8.94},"alignment_confidence":0.9970371723175049},{"word":" didn't","speaker":null,"time_interval":{"start":8.94,"end":9.14},"alignment_confidence":0.9951577484607697},{"word":" know","speaker":null,"time_interval":{"start":9.14,"end":9.28},"alignment_confidence":0.9988951086997986},{"word":" you","speaker":null,"time_interval":{"start":9.28,"end":9.4},"alignment_confidence":0.9883798360824585},{"word":" were","speaker":null,"time_interval":{"start":9.4,"end":9.5},"alignment_confidence":0.9613558053970337},{"word":" there.","speaker":null,"time_interval":{"start":9.5,"end":9.68},"alignment_confidence":0.9957772493362427},{"word":" Neither","speaker":null,"time_interval":{"start":9.98,"end":10.22},"alignment_confidence":0.7853943109512329},{"word":" did","speaker":null,"time_interval":{"start":10.22,"end":10.56},"alignment_confidence":0.991905689239502},{"word":" I.","speaker":null,"time_interval":{"start":10.56,"end":10.64},"alignment_confidence":0.9961293935775757},{"word":" Okay,","speaker":null,"time_interval":{"start":10.64,"end":11.0},"alignment_confidence":0.5810848474502563},{"word":" I","speaker":null,"time_interval":{"start":11.1,"end":11.3},"alignment_confidence":0.8348276615142822},{"word":" thought,","speaker":null,"time_interval":{"start":11.3,"end":11.52},"alignment_confidence":0.8927552700042725},{"word":" you","speaker":null,"time_interval":{"start":11.7,"end":11.82},"alignment_confidence":0.9182931780815125},{"word":" know,","speaker":null,"time_interval":{"start":11.82,"end":11.9},"alignment_confidence":0.9988963603973389},{"word":" I","speaker":null,"time_interval":{"start":11.96,"end":11.98},"alignment_confidence":0.9913604855537415},{"word":" heard","speaker":null,"time_interval":{"start":11.98,"end":12.12},"alignment_confidence":0.992224931716919},{"word":" it","speaker":null,"time_interval":{"start":12.12,"end":12.22},"alignment_confidence":0.41407883167266846},{"word":" deep.","speaker":null,"time_interval":{"start":12.22,"end":12.38},"alignment_confidence":0.7947677969932556}],"speaker":null},{"text":" This is Diane in New Jersey. 
And I'm Sheila in Texas, originally from Chicago.","time_interval":{"start":12.56,"end":17.42},"confidence":0.8329984157521475,"no_speech_confidence":0.012033582665026188,"words":[{"word":" This","speaker":null,"time_interval":{"start":12.56,"end":12.66},"alignment_confidence":0.9929931163787842},{"word":" is","speaker":null,"time_interval":{"start":12.66,"end":12.82},"alignment_confidence":0.9977980852127075},{"word":" Diane","speaker":null,"time_interval":{"start":12.82,"end":13.12},"alignment_confidence":0.884567379951477},{"word":" in","speaker":null,"time_interval":{"start":13.12,"end":13.4},"alignment_confidence":0.9688433408737183},{"word":" New","speaker":null,"time_interval":{"start":13.4,"end":13.54},"alignment_confidence":0.9953000545501709},{"word":" Jersey.","speaker":null,"time_interval":{"start":13.54,"end":13.94},"alignment_confidence":0.9994387030601501},{"word":" And","speaker":null,"time_interval":{"start":14.36,"end":14.52},"alignment_confidence":0.9156690239906311},{"word":" I'm","speaker":null,"time_interval":{"start":14.52,"end":14.72},"alignment_confidence":0.9674701690673828},{"word":" Sheila","speaker":null,"time_interval":{"start":14.72,"end":15.06},"alignment_confidence":0.967194139957428},{"word":" in","speaker":null,"time_interval":{"start":15.06,"end":15.5},"alignment_confidence":0.9479336738586426},{"word":" Texas,","speaker":null,"time_interval":{"start":15.5,"end":15.98},"alignment_confidence":0.9921407699584961},{"word":" originally","speaker":null,"time_interval":{"start":16.16,"end":16.7},"alignment_confidence":0.998430073261261},{"word":" from","speaker":null,"time_interval":{"start":16.7,"end":17.08},"alignment_confidence":0.9941904544830322},{"word":" Chicago.","speaker":null,"time_interval":{"start":17.08,"end":17.42},"alignment_confidence":0.9997063279151917}],"speaker":null},{"text":" Oh, I'm originally from Chicago also. I'm in New Jersey now, though.","time_interval":{"start":18.04,"end":21.34},"confidence":0.8329984157521475,"no_speech_confidence":0.012033582665026188,"words":[{"word":" Oh,","speaker":null,"time_interval":{"start":18.04,"end":18.36},"alignment_confidence":0.8741865754127502},{"word":" I'm","speaker":null,"time_interval":{"start":18.4,"end":18.62},"alignment_confidence":0.9828968346118927},{"word":" originally","speaker":null,"time_interval":{"start":18.62,"end":18.86},"alignment_confidence":0.9990705847740173},{"word":" from","speaker":null,"time_interval":{"start":18.86,"end":19.16},"alignment_confidence":0.9991160035133362},{"word":" Chicago","speaker":null,"time_interval":{"start":19.16,"end":19.56},"alignment_confidence":0.9996306896209717},{"word":" also.","speaker":null,"time_interval":{"start":19.56,"end":19.98},"alignment_confidence":0.8224349021911621},{"word":" I'm","speaker":null,"time_interval":{"start":20.18,"end":20.24},"alignment_confidence":0.9975096881389618},{"word":" in","speaker":null,"time_interval":{"start":20.24,"end":20.34},"alignment_confidence":0.9985995888710022},{"word":" New","speaker":null,"time_interval":{"start":20.34,"end":20.46},"alignment_confidence":0.9962742328643799},{"word":" Jersey","speaker":null,"time_interval":{"start":20.46,"end":20.76},"alignment_confidence":0.99826979637146},{"word":" now,","speaker":null,"time_interval":{"start":20.76,"end":21.14},"alignment_confidence":0.9930222630500793},{"word":" though.","speaker":null,"time_interval":{"start":21.22,"end":21.34},"alignment_confidence":0.9972519278526306}],"speaker":null},{"text":" Well, there is not much difference. 
At least, you know, they all call me a Yankee down here, so what can I say?","time_interval":{"start":21.74,"end":28.32},"confidence":0.8329984157521475,"no_speech_confidence":0.012033582665026188,"words":[{"word":" Well,","speaker":null,"time_interval":{"start":21.74,"end":22.14},"alignment_confidence":0.9893079400062561},{"word":" there","speaker":null,"time_interval":{"start":22.26,"end":22.54},"alignment_confidence":0.9968032836914062},{"word":" is","speaker":null,"time_interval":{"start":22.54,"end":22.7},"alignment_confidence":0.45982280373573303},{"word":" not","speaker":null,"time_interval":{"start":22.7,"end":22.86},"alignment_confidence":0.9138479828834534},{"word":" much","speaker":null,"time_interval":{"start":22.86,"end":23.24},"alignment_confidence":0.9922243356704712},{"word":" difference.","speaker":null,"time_interval":{"start":23.24,"end":23.76},"alignment_confidence":0.9899066090583801},{"word":" At","speaker":null,"time_interval":{"start":24.06,"end":24.1},"alignment_confidence":0.9863474369049072},{"word":" least,","speaker":null,"time_interval":{"start":24.1,"end":24.36},"alignment_confidence":0.9989750385284424},{"word":" you","speaker":null,"time_interval":{"start":24.56,"end":25.02},"alignment_confidence":0.9212563037872314},{"word":" know,","speaker":null,"time_interval":{"start":25.02,"end":25.18},"alignment_confidence":0.9995958209037781},{"word":" they","speaker":null,"time_interval":{"start":25.26,"end":25.36},"alignment_confidence":0.9974746108055115},{"word":" all","speaker":null,"time_interval":{"start":25.36,"end":25.54},"alignment_confidence":0.9924041628837585},{"word":" call","speaker":null,"time_interval":{"start":25.54,"end":25.82},"alignment_confidence":0.9877074360847473},{"word":" me","speaker":null,"time_interval":{"start":25.82,"end":26.08},"alignment_confidence":0.9994006156921387},{"word":" a","speaker":null,"time_interval":{"start":26.08,"end":26.26},"alignment_confidence":0.9883056879043579},{"word":" Yankee","speaker":null,"time_interval":{"start":26.26,"end":26.7},"alignment_confidence":0.9602178335189819},{"word":" down","speaker":null,"time_interval":{"start":26.7,"end":27.02},"alignment_confidence":0.9819284677505493},{"word":" here,","speaker":null,"time_interval":{"start":27.02,"end":27.28},"alignment_confidence":0.9983890056610107},{"word":" so","speaker":null,"time_interval":{"start":27.5,"end":27.58},"alignment_confidence":0.9930232763290405},{"word":" what","speaker":null,"time_interval":{"start":27.58,"end":27.82},"alignment_confidence":0.7675526142120361},{"word":" can","speaker":null,"time_interval":{"start":27.82,"end":28.0},"alignment_confidence":0.41074037551879883},{"word":" I","speaker":null,"time_interval":{"start":28.0,"end":28.22},"alignment_confidence":0.9712743163108826},{"word":" say?","speaker":null,"time_interval":{"start":28.22,"end":28.32},"alignment_confidence":0.9959589838981628}],"speaker":null},{"text":" Oh, I don't hear that in New Jersey now.","time_interval":{"start":28.38,"end":29.82},"confidence":0.8329984157521475,"no_speech_confidence":0.012033582665026188,"words":[{"word":" Oh,","speaker":null,"time_interval":{"start":28.38,"end":28.4},"alignment_confidence":0.8733147978782654},{"word":" I","speaker":null,"time_interval":{"start":28.44,"end":28.54},"alignment_confidence":0.9966418743133545},{"word":" don't","speaker":null,"time_interval":{"start":28.54,"end":28.72},"alignment_confidence":0.9991713762283325},{"word":" 
hear","speaker":null,"time_interval":{"start":28.72,"end":28.86},"alignment_confidence":0.9954941272735596},{"word":" that","speaker":null,"time_interval":{"start":28.86,"end":29.06},"alignment_confidence":0.9986012578010559},{"word":" in","speaker":null,"time_interval":{"start":29.06,"end":29.16},"alignment_confidence":0.6615467071533203},{"word":" New","speaker":null,"time_interval":{"start":29.16,"end":29.26},"alignment_confidence":0.9980757236480713},{"word":" Jersey","speaker":null,"time_interval":{"start":29.26,"end":29.58},"alignment_confidence":0.9985470175743103},{"word":" now.","speaker":null,"time_interval":{"start":29.58,"end":29.82},"alignment_confidence":0.9120670557022095}],"speaker":null}],"transcription_info":{"language":"en","language_confidence":0.9959872364997864},"transcription":{"text":" Hello. Hello. Oh, hello. I didn't know you were there. Neither did I. Okay, I thought, you know, I heard it deep. This is Diane in New Jersey. And I'm Sheila in Texas, originally from Chicago. Oh, I'm originally from Chicago also. I'm in New Jersey now, though. Well, there is not much difference. At least, you know, they all call me a Yankee down here, so what can I say? Oh, I don't hear that in New Jersey now."}} \ No newline at end of file diff --git a/aana/tests/files/expected/whisper/whisper_medium_sd_sample.wav_diar.json b/aana/tests/files/expected/whisper/whisper_medium_sd_sample.wav_diar.json new file mode 100644 index 00000000..baf3775d --- /dev/null +++ b/aana/tests/files/expected/whisper/whisper_medium_sd_sample.wav_diar.json @@ -0,0 +1,827 @@ +{ + "segments": [ + { + "text": " Hello. Hello.", + "time_interval": { + "start": 6.38, + "end": 7.84 + }, + "confidence": null, + "no_speech_confidence": null, + "words": [ + { + "word": " Hello.", + "speaker": "SPEAKER_01", + "time_interval": { + "start": 6.38, + "end": 7.0 + }, + "alignment_confidence": 0.6853185296058655 + }, + { + "word": " Hello.", + "speaker": "SPEAKER_01", + "time_interval": { + "start": 7.5, + "end": 7.84 + }, + "alignment_confidence": 0.7124693989753723 + } + ], + "speaker": "SPEAKER_01" + }, + { + "text": " Oh, hello. 
I didn't know you were there.", + "time_interval": { + "start": 8.3, + "end": 9.68 + }, + "confidence": null, + "no_speech_confidence": null, + "words": [ + { + "word": " Oh,", + "speaker": "SPEAKER_02", + "time_interval": { + "start": 8.3, + "end": 8.48 + }, + "alignment_confidence": 0.8500092029571533 + }, + { + "word": " hello.", + "speaker": "SPEAKER_02", + "time_interval": { + "start": 8.5, + "end": 8.76 + }, + "alignment_confidence": 0.9408962726593018 + }, + { + "word": " I", + "speaker": "SPEAKER_02", + "time_interval": { + "start": 8.92, + "end": 8.94 + }, + "alignment_confidence": 0.9970371723175049 + }, + { + "word": " didn't", + "speaker": "SPEAKER_02", + "time_interval": { + "start": 8.94, + "end": 9.14 + }, + "alignment_confidence": 0.9951577484607697 + }, + { + "word": " know", + "speaker": "SPEAKER_02", + "time_interval": { + "start": 9.14, + "end": 9.28 + }, + "alignment_confidence": 0.9988951086997986 + }, + { + "word": " you", + "speaker": "SPEAKER_02", + "time_interval": { + "start": 9.28, + "end": 9.4 + }, + "alignment_confidence": 0.9883798360824585 + }, + { + "word": " were", + "speaker": "SPEAKER_02", + "time_interval": { + "start": 9.4, + "end": 9.5 + }, + "alignment_confidence": 0.9613558053970337 + }, + { + "word": " there.", + "speaker": "SPEAKER_02", + "time_interval": { + "start": 9.5, + "end": 9.68 + }, + "alignment_confidence": 0.9957772493362427 + } + ], + "speaker": "SPEAKER_02" + }, + { + "text": " Neither did I.", + "time_interval": { + "start": 9.98, + "end": 10.64 + }, + "confidence": null, + "no_speech_confidence": null, + "words": [ + { + "word": " Neither", + "speaker": "SPEAKER_01", + "time_interval": { + "start": 9.98, + "end": 10.22 + }, + "alignment_confidence": 0.7853943109512329 + }, + { + "word": " did", + "speaker": "SPEAKER_01", + "time_interval": { + "start": 10.22, + "end": 10.56 + }, + "alignment_confidence": 0.991905689239502 + }, + { + "word": " I.", + "speaker": "SPEAKER_01", + "time_interval": { + "start": 10.56, + "end": 10.64 + }, + "alignment_confidence": 0.9961293935775757 + } + ], + "speaker": "SPEAKER_01" + }, + { + "text": " Okay, I thought, you know, I heard it deep. 
This is Diane in New Jersey.", + "time_interval": { + "start": 10.64, + "end": 13.94 + }, + "confidence": null, + "no_speech_confidence": null, + "words": [ + { + "word": " Okay,", + "speaker": "SPEAKER_02", + "time_interval": { + "start": 10.64, + "end": 11.0 + }, + "alignment_confidence": 0.5810848474502563 + }, + { + "word": " I", + "speaker": "SPEAKER_02", + "time_interval": { + "start": 11.1, + "end": 11.3 + }, + "alignment_confidence": 0.8348276615142822 + }, + { + "word": " thought,", + "speaker": "SPEAKER_02", + "time_interval": { + "start": 11.3, + "end": 11.52 + }, + "alignment_confidence": 0.8927552700042725 + }, + { + "word": " you", + "speaker": "SPEAKER_02", + "time_interval": { + "start": 11.7, + "end": 11.82 + }, + "alignment_confidence": 0.9182931780815125 + }, + { + "word": " know,", + "speaker": "SPEAKER_02", + "time_interval": { + "start": 11.82, + "end": 11.9 + }, + "alignment_confidence": 0.9988963603973389 + }, + { + "word": " I", + "speaker": "SPEAKER_02", + "time_interval": { + "start": 11.96, + "end": 11.98 + }, + "alignment_confidence": 0.9913604855537415 + }, + { + "word": " heard", + "speaker": "SPEAKER_02", + "time_interval": { + "start": 11.98, + "end": 12.12 + }, + "alignment_confidence": 0.992224931716919 + }, + { + "word": " it", + "speaker": "SPEAKER_02", + "time_interval": { + "start": 12.12, + "end": 12.22 + }, + "alignment_confidence": 0.41407883167266846 + }, + { + "word": " deep.", + "speaker": "SPEAKER_02", + "time_interval": { + "start": 12.22, + "end": 12.38 + }, + "alignment_confidence": 0.7947677969932556 + }, + { + "word": " This", + "speaker": "SPEAKER_02", + "time_interval": { + "start": 12.56, + "end": 12.66 + }, + "alignment_confidence": 0.9929931163787842 + }, + { + "word": " is", + "speaker": "SPEAKER_02", + "time_interval": { + "start": 12.66, + "end": 12.82 + }, + "alignment_confidence": 0.9977980852127075 + }, + { + "word": " Diane", + "speaker": "SPEAKER_02", + "time_interval": { + "start": 12.82, + "end": 13.12 + }, + "alignment_confidence": 0.884567379951477 + }, + { + "word": " in", + "speaker": "SPEAKER_02", + "time_interval": { + "start": 13.12, + "end": 13.4 + }, + "alignment_confidence": 0.9688433408737183 + }, + { + "word": " New", + "speaker": "SPEAKER_02", + "time_interval": { + "start": 13.4, + "end": 13.54 + }, + "alignment_confidence": 0.9953000545501709 + }, + { + "word": " Jersey.", + "speaker": "SPEAKER_02", + "time_interval": { + "start": 13.54, + "end": 13.94 + }, + "alignment_confidence": 0.9994387030601501 + } + ], + "speaker": "SPEAKER_02" + }, + { + "text": " And I'm Sheila in Texas, originally from Chicago.", + "time_interval": { + "start": 14.36, + "end": 17.42 + }, + "confidence": null, + "no_speech_confidence": null, + "words": [ + { + "word": " And", + "speaker": "SPEAKER_00", + "time_interval": { + "start": 14.36, + "end": 14.52 + }, + "alignment_confidence": 0.9156690239906311 + }, + { + "word": " I'm", + "speaker": "SPEAKER_00", + "time_interval": { + "start": 14.52, + "end": 14.72 + }, + "alignment_confidence": 0.9674701690673828 + }, + { + "word": " Sheila", + "speaker": "SPEAKER_00", + "time_interval": { + "start": 14.72, + "end": 15.06 + }, + "alignment_confidence": 0.967194139957428 + }, + { + "word": " in", + "speaker": "SPEAKER_00", + "time_interval": { + "start": 15.06, + "end": 15.5 + }, + "alignment_confidence": 0.9479336738586426 + }, + { + "word": " Texas,", + "speaker": "SPEAKER_00", + "time_interval": { + "start": 15.5, + "end": 15.98 + }, + "alignment_confidence": 0.9921407699584961 + }, + 
{ + "word": " originally", + "speaker": "SPEAKER_00", + "time_interval": { + "start": 16.16, + "end": 16.7 + }, + "alignment_confidence": 0.998430073261261 + }, + { + "word": " from", + "speaker": "SPEAKER_00", + "time_interval": { + "start": 16.7, + "end": 17.08 + }, + "alignment_confidence": 0.9941904544830322 + }, + { + "word": " Chicago.", + "speaker": "SPEAKER_00", + "time_interval": { + "start": 17.08, + "end": 17.42 + }, + "alignment_confidence": 0.9997063279151917 + } + ], + "speaker": "SPEAKER_00" + }, + { + "text": " Oh, I'm originally from Chicago also. I'm in New Jersey now, though.", + "time_interval": { + "start": 18.04, + "end": 21.34 + }, + "confidence": null, + "no_speech_confidence": null, + "words": [ + { + "word": " Oh,", + "speaker": "SPEAKER_02", + "time_interval": { + "start": 18.04, + "end": 18.36 + }, + "alignment_confidence": 0.8741865754127502 + }, + { + "word": " I'm", + "speaker": "SPEAKER_02", + "time_interval": { + "start": 18.4, + "end": 18.62 + }, + "alignment_confidence": 0.9828968346118927 + }, + { + "word": " originally", + "speaker": "SPEAKER_02", + "time_interval": { + "start": 18.62, + "end": 18.86 + }, + "alignment_confidence": 0.9990705847740173 + }, + { + "word": " from", + "speaker": "SPEAKER_02", + "time_interval": { + "start": 18.86, + "end": 19.16 + }, + "alignment_confidence": 0.9991160035133362 + }, + { + "word": " Chicago", + "speaker": "SPEAKER_02", + "time_interval": { + "start": 19.16, + "end": 19.56 + }, + "alignment_confidence": 0.9996306896209717 + }, + { + "word": " also.", + "speaker": "SPEAKER_02", + "time_interval": { + "start": 19.56, + "end": 19.98 + }, + "alignment_confidence": 0.8224349021911621 + }, + { + "word": " I'm", + "speaker": "SPEAKER_02", + "time_interval": { + "start": 20.18, + "end": 20.24 + }, + "alignment_confidence": 0.9975096881389618 + }, + { + "word": " in", + "speaker": "SPEAKER_02", + "time_interval": { + "start": 20.24, + "end": 20.34 + }, + "alignment_confidence": 0.9985995888710022 + }, + { + "word": " New", + "speaker": "SPEAKER_02", + "time_interval": { + "start": 20.34, + "end": 20.46 + }, + "alignment_confidence": 0.9962742328643799 + }, + { + "word": " Jersey", + "speaker": "SPEAKER_02", + "time_interval": { + "start": 20.46, + "end": 20.76 + }, + "alignment_confidence": 0.99826979637146 + }, + { + "word": " now,", + "speaker": "SPEAKER_02", + "time_interval": { + "start": 20.76, + "end": 21.14 + }, + "alignment_confidence": 0.9930222630500793 + }, + { + "word": " though.", + "speaker": "SPEAKER_02", + "time_interval": { + "start": 21.22, + "end": 21.34 + }, + "alignment_confidence": 0.9972519278526306 + } + ], + "speaker": "SPEAKER_02" + }, + { + "text": " Well, there is not much difference. 
At least, you know, they all call me a Yankee down here, so what can I say?", + "time_interval": { + "start": 21.74, + "end": 28.32 + }, + "confidence": null, + "no_speech_confidence": null, + "words": [ + { + "word": " Well,", + "speaker": "SPEAKER_00", + "time_interval": { + "start": 21.74, + "end": 22.14 + }, + "alignment_confidence": 0.9893079400062561 + }, + { + "word": " there", + "speaker": "SPEAKER_00", + "time_interval": { + "start": 22.26, + "end": 22.54 + }, + "alignment_confidence": 0.9968032836914062 + }, + { + "word": " is", + "speaker": "SPEAKER_00", + "time_interval": { + "start": 22.54, + "end": 22.7 + }, + "alignment_confidence": 0.45982280373573303 + }, + { + "word": " not", + "speaker": "SPEAKER_00", + "time_interval": { + "start": 22.7, + "end": 22.86 + }, + "alignment_confidence": 0.9138479828834534 + }, + { + "word": " much", + "speaker": "SPEAKER_00", + "time_interval": { + "start": 22.86, + "end": 23.24 + }, + "alignment_confidence": 0.9922243356704712 + }, + { + "word": " difference.", + "speaker": "SPEAKER_00", + "time_interval": { + "start": 23.24, + "end": 23.76 + }, + "alignment_confidence": 0.9899066090583801 + }, + { + "word": " At", + "speaker": "SPEAKER_00", + "time_interval": { + "start": 24.06, + "end": 24.1 + }, + "alignment_confidence": 0.9863474369049072 + }, + { + "word": " least,", + "speaker": "SPEAKER_00", + "time_interval": { + "start": 24.1, + "end": 24.36 + }, + "alignment_confidence": 0.9989750385284424 + }, + { + "word": " you", + "speaker": "SPEAKER_00", + "time_interval": { + "start": 24.56, + "end": 25.02 + }, + "alignment_confidence": 0.9212563037872314 + }, + { + "word": " know,", + "speaker": "SPEAKER_00", + "time_interval": { + "start": 25.02, + "end": 25.18 + }, + "alignment_confidence": 0.9995958209037781 + }, + { + "word": " they", + "speaker": "SPEAKER_00", + "time_interval": { + "start": 25.26, + "end": 25.36 + }, + "alignment_confidence": 0.9974746108055115 + }, + { + "word": " all", + "speaker": "SPEAKER_00", + "time_interval": { + "start": 25.36, + "end": 25.54 + }, + "alignment_confidence": 0.9924041628837585 + }, + { + "word": " call", + "speaker": "SPEAKER_00", + "time_interval": { + "start": 25.54, + "end": 25.82 + }, + "alignment_confidence": 0.9877074360847473 + }, + { + "word": " me", + "speaker": "SPEAKER_00", + "time_interval": { + "start": 25.82, + "end": 26.08 + }, + "alignment_confidence": 0.9994006156921387 + }, + { + "word": " a", + "speaker": "SPEAKER_00", + "time_interval": { + "start": 26.08, + "end": 26.26 + }, + "alignment_confidence": 0.9883056879043579 + }, + { + "word": " Yankee", + "speaker": "SPEAKER_00", + "time_interval": { + "start": 26.26, + "end": 26.7 + }, + "alignment_confidence": 0.9602178335189819 + }, + { + "word": " down", + "speaker": "SPEAKER_00", + "time_interval": { + "start": 26.7, + "end": 27.02 + }, + "alignment_confidence": 0.9819284677505493 + }, + { + "word": " here,", + "speaker": "SPEAKER_00", + "time_interval": { + "start": 27.02, + "end": 27.28 + }, + "alignment_confidence": 0.9983890056610107 + }, + { + "word": " so", + "speaker": "SPEAKER_00", + "time_interval": { + "start": 27.5, + "end": 27.58 + }, + "alignment_confidence": 0.9930232763290405 + }, + { + "word": " what", + "speaker": "SPEAKER_00", + "time_interval": { + "start": 27.58, + "end": 27.82 + }, + "alignment_confidence": 0.7675526142120361 + }, + { + "word": " can", + "speaker": "SPEAKER_00", + "time_interval": { + "start": 27.82, + "end": 28.0 + }, + "alignment_confidence": 0.41074037551879883 + }, + { + "word": " 
I", + "speaker": "SPEAKER_00", + "time_interval": { + "start": 28.0, + "end": 28.22 + }, + "alignment_confidence": 0.9712743163108826 + }, + { + "word": " say?", + "speaker": "SPEAKER_00", + "time_interval": { + "start": 28.22, + "end": 28.32 + }, + "alignment_confidence": 0.9959589838981628 + } + ], + "speaker": "SPEAKER_00" + }, + { + "text": " Oh, I don't hear that in New Jersey now.", + "time_interval": { + "start": 28.38, + "end": 29.82 + }, + "confidence": null, + "no_speech_confidence": null, + "words": [ + { + "word": " Oh,", + "speaker": "SPEAKER_02", + "time_interval": { + "start": 28.38, + "end": 28.4 + }, + "alignment_confidence": 0.8733147978782654 + }, + { + "word": " I", + "speaker": "SPEAKER_02", + "time_interval": { + "start": 28.44, + "end": 28.54 + }, + "alignment_confidence": 0.9966418743133545 + }, + { + "word": " don't", + "speaker": "SPEAKER_02", + "time_interval": { + "start": 28.54, + "end": 28.72 + }, + "alignment_confidence": 0.9991713762283325 + }, + { + "word": " hear", + "speaker": "SPEAKER_02", + "time_interval": { + "start": 28.72, + "end": 28.86 + }, + "alignment_confidence": 0.9954941272735596 + }, + { + "word": " that", + "speaker": "SPEAKER_02", + "time_interval": { + "start": 28.86, + "end": 29.06 + }, + "alignment_confidence": 0.9986012578010559 + }, + { + "word": " in", + "speaker": "SPEAKER_02", + "time_interval": { + "start": 29.06, + "end": 29.16 + }, + "alignment_confidence": 0.6615467071533203 + }, + { + "word": " New", + "speaker": "SPEAKER_02", + "time_interval": { + "start": 29.16, + "end": 29.26 + }, + "alignment_confidence": 0.9980757236480713 + }, + { + "word": " Jersey", + "speaker": "SPEAKER_02", + "time_interval": { + "start": 29.26, + "end": 29.58 + }, + "alignment_confidence": 0.9985470175743103 + }, + { + "word": " now.", + "speaker": "SPEAKER_02", + "time_interval": { + "start": 29.58, + "end": 29.82 + }, + "alignment_confidence": 0.9120670557022095 + } + ], + "speaker": "SPEAKER_02" + } + ], + "transcription_info": { + "language": "en", + "language_confidence": 0.9959872364997864 + }, + "transcription": { + "text": " Hello. Hello. Oh, hello. I didn't know you were there. Neither did I. Okay, I thought, you know, I heard it deep. This is Diane in New Jersey. And I'm Sheila in Texas, originally from Chicago. Oh, I'm originally from Chicago also. I'm in New Jersey now, though. Well, there is not much difference. At least, you know, they all call me a Yankee down here, so what can I say? Oh, I don't hear that in New Jersey now." 
+ } +} \ No newline at end of file diff --git a/aana/tests/units/test_speaker.py b/aana/tests/units/test_speaker.py new file mode 100644 index 00000000..eb72a3d4 --- /dev/null +++ b/aana/tests/units/test_speaker.py @@ -0,0 +1,67 @@ +# ruff: noqa: S101 +import json +from importlib import resources +from pathlib import Path +from typing import Literal + +import pytest + +from aana.core.models.asr import ( + AsrSegment, + AsrTranscription, + AsrTranscriptionInfo, +) +from aana.core.models.speaker import SpeakerDiarizationSegment +from aana.processors.speaker import ( + SpeakerDiarizationOutput, + WhisperOutput, + asr_postprocessing_for_diarization, +) +from aana.tests.utils import verify_deployment_results + + +@pytest.mark.parametrize("audio_file", ["sd_sample.wav"]) +def test_asr_diarization_post_process(audio_file: Literal["sd_sample.wav"]): + """Test that the ASR output can be processed to generate diarized transcription.""" + # load precomputed asr and diarization outputs + asr_path = ( + resources.files("aana.tests.files.expected.whisper") + / f"whisper_medium_{audio_file}.json" + ) + diar_path = ( + resources.files("aana.tests.files.expected.sd") + / f"{Path(audio_file).stem}.json" + ) + expected_results_path = ( + resources.files("aana.tests.files.expected.whisper") + / f"whisper_medium_{audio_file}_diar.json" + ) + # convert to WhisperOutput and SpeakerDiarizationOutput + + with Path.open(asr_path, "r") as json_file: + asr_op = json.load(json_file) + + asr_output = WhisperOutput( + segments=[AsrSegment.model_validate(segment) for segment in asr_op["segments"]], + transcription_info=AsrTranscriptionInfo.model_validate( + asr_op["transcription_info"] + ), + transcription=AsrTranscription.model_validate(asr_op["transcription"]), + ) + + with Path.open(diar_path, "r") as json_file: + diar_op = json.load(json_file) + + diar_output = SpeakerDiarizationOutput( + segments=[ + SpeakerDiarizationSegment.model_validate(segment) + for segment in diar_op["segments"] + ], + ) + + processed_transcription = asr_postprocessing_for_diarization( + diar_output, + asr_output, + ) + + verify_deployment_results(expected_results_path, processed_transcription) diff --git a/notebooks/diarized_transcription_example.ipynb b/notebooks/diarized_transcription_example.ipynb new file mode 100644 index 00000000..3d0584e0 --- /dev/null +++ b/notebooks/diarized_transcription_example.ipynb @@ -0,0 +1,292 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Build an App for diarized transcription Using Aana SDK\n", + "\n", + "This notebook provides an example of getting diarized transcription from video. Please note that the pyannote diarization model is a gated model. Follow [speaker diarization deployment docs](./../docs/pages/model_hub/speaker_recognition.md) to get access to the model." 
+ ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "As a first step, set the environment and install aana SDK" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Note: you may need to restart the kernel to use updated packages.\n", + "env: CUDA_VISIBLE_DEVICES=0\n" + ] + } + ], + "source": [ + "%pip install aana -qqqU \n", + "%env CUDA_VISIBLE_DEVICES=0\n", + "# %env HF_TOKEN=" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Define Whisper and PyannoteSpeakerDiarization deployments, define the TranscribeVideoWithDiarEndpoint for diarized transcription. Register deployments and the endpoints." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "from aana.api.api_generation import Endpoint\n", + "from aana.core.models.base import pydantic_to_dict\n", + "from aana.core.models.speaker import PyannoteSpeakerDiarizationParams\n", + "from aana.core.models.video import VideoInput\n", + "from aana.core.models.whisper import WhisperParams\n", + "from aana.deployments.aana_deployment_handle import AanaDeploymentHandle\n", + "from aana.deployments.pyannote_speaker_diarization_deployment import (\n", + " PyannoteSpeakerDiarizationConfig,\n", + " PyannoteSpeakerDiarizationDeployment,\n", + ")\n", + "from aana.deployments.whisper_deployment import (\n", + " WhisperComputeType,\n", + " WhisperConfig,\n", + " WhisperDeployment,\n", + " WhisperModelSize,\n", + " WhisperOutput,\n", + ")\n", + "from aana.integrations.external.yt_dlp import download_video\n", + "from aana.processors.remote import run_remote\n", + "from aana.processors.speaker import asr_postprocessing_for_diarization\n", + "from aana.processors.video import extract_audio\n", + "from aana.sdk import AanaSDK\n", + "\n", + "# Define the model deployments.\n", + "asr_deployment = WhisperDeployment.options(\n", + " num_replicas=1,\n", + " ray_actor_options={\n", + " \"num_gpus\": 0.25\n", + " }, # Remove this line if you want to run Whisper on a CPU.# Also change type to float32.\n", + " user_config=WhisperConfig(\n", + " model_size=WhisperModelSize.SMALL,\n", + " compute_type=WhisperComputeType.FLOAT16,\n", + " ).model_dump(mode=\"json\"),\n", + ")\n", + "diarization_deployment = PyannoteSpeakerDiarizationDeployment.options(\n", + " num_replicas=1,\n", + " ray_actor_options={\n", + " \"num_gpus\": 0.1\n", + " }, # Remove this line if you want to run the model on a CPU.\n", + " user_config=PyannoteSpeakerDiarizationConfig(\n", + " model_id=\"pyannote/speaker-diarization-3.1\"\n", + " ).model_dump(mode=\"json\"),\n", + ")\n", + "deployments = [\n", + " {\"name\": \"asr_deployment\", \"instance\": asr_deployment},\n", + " {\"name\": \"diarization_deployment\", \"instance\": diarization_deployment},\n", + "]\n", + "\n", + "\n", + "# Define the endpoint to transcribe the video with diarization.\n", + "class TranscribeVideoWithDiarEndpoint(Endpoint):\n", + " \"\"\"Transcribe video with diarization endpoint.\"\"\"\n", + "\n", + " async def initialize(self):\n", + " \"\"\"Initialize the endpoint.\"\"\"\n", + " self.asr_handle = await AanaDeploymentHandle.create(\"asr_deployment\")\n", + " self.diar_handle = await AanaDeploymentHandle.create(\"diarization_deployment\")\n", + " await super().initialize()\n", + "\n", + " async def run(\n", + " self,\n", + " video: VideoInput,\n", + " whisper_params: WhisperParams,\n", + " diar_params: 
PyannoteSpeakerDiarizationParams,\n", + " ) -> WhisperOutput:\n", + " \"\"\"Transcribe video with diarization.\"\"\"\n", + " video_obj = await run_remote(download_video)(video_input=video)\n", + " audio = extract_audio(video=video_obj)\n", + "\n", + " # diarized transcript requires word_timestamps from ASR\n", + " whisper_params.word_timestamps = True\n", + " transcription = await self.asr_handle.transcribe(\n", + " audio=audio, params=whisper_params\n", + " )\n", + " diarized_output = await self.diar_handle.diarize(\n", + " audio=audio, params=diar_params\n", + " )\n", + " transcription = asr_postprocessing_for_diarization(\n", + " diarized_output, transcription\n", + " )\n", + " output = pydantic_to_dict(transcription)\n", + "\n", + " output_keys = [\"time_interval\", \"speaker\", \"text\"]\n", + " filtered_data = [\n", + " {k: v for k, v in entry.items() if k in output_keys}\n", + " for entry in output[\"segments\"]\n", + " ]\n", + "\n", + " return filtered_data\n", + "\n", + "\n", + "endpoints = [\n", + " {\n", + " \"name\": \"transcribe_video\",\n", + " \"path\": \"/video/transcribe\",\n", + " \"summary\": \"Transcribe a video\",\n", + " \"endpoint_cls\": TranscribeVideoWithDiarEndpoint,\n", + " },\n", + "]\n", + "\n", + "aana_app = AanaSDK(name=\"transcribe_video_app\")\n", + "\n", + "for deployment in deployments:\n", + " aana_app.register_deployment(**deployment)\n", + "\n", + "for endpoint in endpoints:\n", + " aana_app.register_endpoint(**endpoint)\n", + "\n", + "aana_app.connect(\n", + " host=\"127.0.0.1\", port=8000, show_logs=False\n", + ") # Connects to the Ray cluster or starts a new one.\n", + "aana_app.migrate() # Runs the migrations to create the database tables." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Start the App!" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
Deployed successfully.\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[32mDeployed successfully.\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
Documentation is available at http://127.0.0.1:8000/docs and http://127.0.0.1:8000/redoc\n",
+       "
\n" + ], + "text/plain": [ + "Documentation is available at \u001b]8;id=135178;http://127.0.0.1:8000/docs\u001b\\\u001b[4;94mhttp://127.0.0.1:8000/docs\u001b[0m\u001b]8;;\u001b\\ and \u001b]8;id=716794;http://127.0.0.1:8000/redoc\u001b\\\u001b[4;94mhttp://127.0.0.1:8000/redoc\u001b[0m\u001b]8;;\u001b\\\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "aana_app.deploy()" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "aana_app.shutdown()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now that we have the app running, lets provide an example audio with multiple speakers for transcription." + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{'error': 'FileNotFoundError', 'message': \"[Errno 2] No such file or directory: 'aana/tests/files/expected/whisper/whisper_medium_sd_sample.wav_speaker.json'\", 'data': {}, 'stacktrace': 'Traceback (most recent call last):\\n File \"/workspaces/aana_sdk/aana/api/api_generation.py\", line 337, in route_func_body\\n output = await self.run(**data_dict)\\n File \"/tmp/ipykernel_1758291/2995001203.py\", line 78, in run\\n File \"/workspaces/aana_sdk/aana/processors/speaker.py\", line 411, in asr_postprocessing_for_diarization\\n speaker_labelled_transcription = assign_word_speakers(\\n File \"/workspaces/aana_sdk/aana/processors/speaker.py\", line 87, in assign_word_speakers\\n with open(expected_output_path, \"w\") as json_file:\\nFileNotFoundError: [Errno 2] No such file or directory: \\'aana/tests/files/expected/whisper/whisper_medium_sd_sample.wav_speaker.json\\'\\n'}\n" + ] + } + ], + "source": [ + "import json\n", + "\n", + "import requests\n", + "\n", + "video = {\n", + " \"path\": \"../aana/tests/files/audios/sd_sample.wav\", # Video URL, Aana SDK supports URLs (including YouTube), file paths or even raw video data\n", + " \"media_id\": \"sd_sample\", # Media ID, so we can ask questions about the video later by using this ID\n", + "}\n", + "\n", + "data = {\n", + " \"whisper_params\": {\n", + " \"word_timestamps\": True, # Enable word_timestamps\n", + " },\n", + " \"video\": video,\n", + "}\n", + "\n", + "url = \"http://127.0.0.1:8000/video/transcribe\"\n", + "\n", + "# No streaming support possible for diarized transcription as it needs complete ASR output beforehand.\n", + "response = requests.post(url, data={\"body\": json.dumps(data)})\n", + "\n", + "print(response.json())" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Each transcribed segment comes with a corresponding speaker label as well!" 
+ ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "aana-vIr3-B0u-py3.10", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.14" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} From 8e4c433dc5c064e2c881f6242fd0a74fbaf18bc6 Mon Sep 17 00:00:00 2001 From: jiltseb Date: Mon, 16 Sep 2024 09:16:29 +0000 Subject: [PATCH 03/18] fixing notebook --- .../diarized_transcription_example.ipynb | 52 ++++++++----------- 1 file changed, 22 insertions(+), 30 deletions(-) diff --git a/notebooks/diarized_transcription_example.ipynb b/notebooks/diarized_transcription_example.ipynb index 3d0584e0..e5d5a5be 100644 --- a/notebooks/diarized_transcription_example.ipynb +++ b/notebooks/diarized_transcription_example.ipynb @@ -18,22 +18,14 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 16, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Note: you may need to restart the kernel to use updated packages.\n", - "env: CUDA_VISIBLE_DEVICES=0\n" - ] - } - ], + "outputs": [], "source": [ - "%pip install aana -qqqU \n", - "%env CUDA_VISIBLE_DEVICES=0\n", - "# %env HF_TOKEN=" + "import os\n", + "\n", + "os.environ[\"HF_TOKEN\"] = \"YOUR_HF_TOKEN_GOES_HERE\"\n", + "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"" ] }, { @@ -45,9 +37,18 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 11, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "INFO [alembic.runtime.migration] Context impl SQLiteImpl.\n", + "INFO [alembic.runtime.migration] Will assume non-transactional DDL.\n" + ] + } + ], "source": [ "from aana.api.api_generation import Endpoint\n", "from aana.core.models.base import pydantic_to_dict\n", @@ -137,7 +138,7 @@ " for entry in output[\"segments\"]\n", " ]\n", "\n", - " return filtered_data\n", + " return {\"segments\": filtered_data}\n", "\n", "\n", "endpoints = [\n", @@ -172,7 +173,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 12, "metadata": {}, "outputs": [ { @@ -195,7 +196,7 @@ "\n" ], "text/plain": [ - "Documentation is available at \u001b]8;id=135178;http://127.0.0.1:8000/docs\u001b\\\u001b[4;94mhttp://127.0.0.1:8000/docs\u001b[0m\u001b]8;;\u001b\\ and \u001b]8;id=716794;http://127.0.0.1:8000/redoc\u001b\\\u001b[4;94mhttp://127.0.0.1:8000/redoc\u001b[0m\u001b]8;;\u001b\\\n" + "Documentation is available at \u001b]8;id=337020;http://127.0.0.1:8000/docs\u001b\\\u001b[4;94mhttp://127.0.0.1:8000/docs\u001b[0m\u001b]8;;\u001b\\ and \u001b]8;id=418863;http://127.0.0.1:8000/redoc\u001b\\\u001b[4;94mhttp://127.0.0.1:8000/redoc\u001b[0m\u001b]8;;\u001b\\\n" ] }, "metadata": {}, @@ -206,15 +207,6 @@ "aana_app.deploy()" ] }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [], - "source": [ - "aana_app.shutdown()" - ] - }, { "cell_type": "markdown", "metadata": {}, @@ -224,14 +216,14 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 15, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "{'error': 'FileNotFoundError', 'message': \"[Errno 2] No such file or directory: 'aana/tests/files/expected/whisper/whisper_medium_sd_sample.wav_speaker.json'\", 'data': {}, 'stacktrace': 'Traceback (most 
recent call last):\\n File \"/workspaces/aana_sdk/aana/api/api_generation.py\", line 337, in route_func_body\\n output = await self.run(**data_dict)\\n File \"/tmp/ipykernel_1758291/2995001203.py\", line 78, in run\\n File \"/workspaces/aana_sdk/aana/processors/speaker.py\", line 411, in asr_postprocessing_for_diarization\\n speaker_labelled_transcription = assign_word_speakers(\\n File \"/workspaces/aana_sdk/aana/processors/speaker.py\", line 87, in assign_word_speakers\\n with open(expected_output_path, \"w\") as json_file:\\nFileNotFoundError: [Errno 2] No such file or directory: \\'aana/tests/files/expected/whisper/whisper_medium_sd_sample.wav_speaker.json\\'\\n'}\n" + "{'segments': [{'text': ' Hello. Hello.', 'time_interval': {'start': 6.38, 'end': 7.84}, 'speaker': 'SPEAKER_01'}, {'text': \" Oh, hello. I didn't know you were there.\", 'time_interval': {'start': 8.3, 'end': 9.68}, 'speaker': 'SPEAKER_02'}, {'text': ' Neither did I.', 'time_interval': {'start': 9.98, 'end': 10.64}, 'speaker': 'SPEAKER_01'}, {'text': ' Okay, I thought, you know, I heard it deep. This is Diane in New Jersey.', 'time_interval': {'start': 10.64, 'end': 13.94}, 'speaker': 'SPEAKER_02'}, {'text': \" And I'm Sheila in Texas, originally from Chicago.\", 'time_interval': {'start': 14.36, 'end': 17.42}, 'speaker': 'SPEAKER_00'}, {'text': \" Oh, I'm originally from Chicago also. I'm in New Jersey now, though.\", 'time_interval': {'start': 18.04, 'end': 21.34}, 'speaker': 'SPEAKER_02'}, {'text': ' Well, there is not much difference. At least, you know, they all call me a Yankee down here, so what can I say?', 'time_interval': {'start': 21.74, 'end': 28.32}, 'speaker': 'SPEAKER_00'}, {'text': \" Oh, I don't hear that in New Jersey now.\", 'time_interval': {'start': 28.38, 'end': 29.82}, 'speaker': 'SPEAKER_02'}]}\n" ] } ], From 1c933f635fa522bd729dae91b05121f4ef4ba5af Mon Sep 17 00:00:00 2001 From: jiltseb Date: Thu, 19 Sep 2024 11:33:17 +0000 Subject: [PATCH 04/18] fixes, modified post processing for segment length, confidence and no_speech_confidence values --- aana/processors/speaker.py | 316 ++++++++++++++---- ...pyannote_speaker_diarization_deployment.py | 6 +- .../whisper_medium_sd_sample.wav_diar.json | 32 +- aana/tests/units/test_speaker.py | 6 +- 4 files changed, 266 insertions(+), 94 deletions(-) diff --git a/aana/processors/speaker.py b/aana/processors/speaker.py index 0df24b26..b90ddebd 100644 --- a/aana/processors/speaker.py +++ b/aana/processors/speaker.py @@ -1,5 +1,5 @@ from collections import defaultdict -from typing import TypedDict +from typing import TypedDict, Union from aana.core.models.asr import ( AsrSegment, @@ -10,10 +10,10 @@ from aana.core.models.speaker import SpeakerDiarizationSegment from aana.core.models.time import TimeInterval -# Utility functions for speaker-related processing +# Utility functions for speaker-related processing in audio -# redefine SpeakerDiarizationOutput and WhisperOutput to prevent circular imports +# Redefine SpeakerDiarizationOutput and WhisperOutput to prevent circular imports class SpeakerDiarizationOutput(TypedDict): """The output of the Speaker Diarization model. @@ -38,11 +38,10 @@ class WhisperOutput(TypedDict): transcription: AsrTranscription -# Define sentence ending punctuations: +# Define sentence ending punctuations to split segments at sentence endings: sentence_ending_punctuations = ".?!" -# AsrSegment and AsrWord has a speaker label that defaults to None. 
def assign_word_speakers( diarized_output: SpeakerDiarizationOutput, transcription: WhisperOutput, @@ -50,13 +49,13 @@ def assign_word_speakers( ) -> WhisperOutput: """Assigns speaker labels to each segment and word in the transcription based on diarized output. - Parameters: - - diarized_output (SpeakerDiarizationOutput): Contains speaker diarization segments. - - transcription (WhisperOutput): Transcription data with segments, text, and language_info. - - fill_nearest (bool): If True, assigns the closest speaker even if there's no positive overlap. Default is False. + Args: + diarized_output (SpeakerDiarizationOutput): Contains speaker diarization segments. + transcription (WhisperOutput): Transcription data with segments, text, and language_info. + fill_nearest (bool): If True, assigns the closest speaker even if there's no positive overlap. Default is False. Returns: - - transcription (WhisperOutput): Transcription updated in-place with the assigned speaker labels. + transcription (WhisperOutput): Transcription updated in-place with the assigned speaker labels. """ for segment in transcription["segments"]: # Assign speaker to segment @@ -88,14 +87,14 @@ def get_speaker_for_interval( ) -> str | None: """Determines the speaker for a given time interval based on diarized segments. - Parameters: - - sd_segments (list[SpeakerDiarizationSegment]): List of speaker diarization segments. - - start_time (float): Start time of the interval. - - end_time (float): End time of the interval. - - fill_nearest (bool): If True, selects the closest speaker even with no overlap. + Args: + sd_segments (list[SpeakerDiarizationSegment]): List of speaker diarization segments. + start_time (float): Start time of the interval. + end_time (float): End time of the interval. + fill_nearest (bool): If True, selects the closest speaker even with no overlap. Returns: - - str | None: The identified speaker label, or None if no speaker is found. + str | None: The identified speaker label, or None if no speaker is found. """ overlaps = [] @@ -184,13 +183,13 @@ def find_nearest_speaker( ) -> str | None: """Find the nearest speaker label in the word_speaker_mapping either forward or backward. - Parameters: - - index (int): The index to start searching from. - - word_speaker_mapping (list[AsrWord]): List of word-speaker mappings. - - reverse (bool): Search backwards if True; forwards if False. Default is False. + Args: + index (int): The index to start searching from. + word_speaker_mapping (list[AsrWord]): List of word-speaker mappings. + reverse (bool): Search backwards if True; forwards if False. Default is False. Returns: - - str | None: The nearest speaker found or None if not found. + str | None: The nearest speaker found or None if not found. """ step = -1 if reverse else 1 for i in range(index, len(word_speaker_mapping) if not reverse else -1, step): @@ -204,12 +203,12 @@ def align_with_punctuation( ) -> list[AsrWord]: """Aligns speaker labels with sentence boundaries defined by punctuation. - Parameters: - - transcription (WhisperOutput): transcription with speaker information. - - max_words_in_sentence (int): Maximum number of words allowed in a sentence. + Args: + transcription (WhisperOutput): transcription with speaker information. + max_words_in_sentence (int): Maximum number of words allowed in a sentence. Returns: - - list[AsrWord]: Realigned word-speaker mappings. + word_speaker_mapping: (list[AsrWord]): Realigned word-speaker mappings. 
""" new_segments = [segment.words for segment in transcription["segments"]] word_speaker_mapping = [word for segment in new_segments for word in segment] @@ -259,76 +258,232 @@ def align_with_punctuation( return word_speaker_mapping +def create_new_segment( + word_info: AsrWord, speaker: str | None, is_empty: bool = False +) -> AsrSegment: + """Creates a new segment based on word information. + + Args: + word_info (AsrWord): The word information containing text, timing, etc. + speaker (str | None): The speaker associated with this word. + is_empty (bool): If True, creates an empty segment (for punctuation-only segments). + + Returns: + AsrSegment: A new segment with the provided word information and speaker details. + """ + return AsrSegment( + time_interval=TimeInterval( + start=word_info.time_interval.start + if not is_empty + else word_info.time_interval.end, + end=word_info.time_interval.end, + ), + text=word_info.word if not is_empty else "", + speaker=speaker, + words=[word_info] if not is_empty else [], + confidence=None, + no_speech_confidence=None, + ) + + def create_speaker_segments( - word_list: list[AsrWord], merge: bool = False + word_list: list[AsrWord], max_words_per_segment: int = 50 ) -> list[AsrSegment]: - """Creates speaker segments from a list of words with speaker annotations and timing information. + """Creates speaker segments from a list of words with speaker annotations. - Parameters: - - word_list (List[AsrWord]): A list of words with associated speaker and timing details. - - merge (bool): If True, merges consecutive segments from the same speaker into one segment. + Args: + word_list (List[AsrWord]): A list of words with associated speaker and timing details. + max_words_per_segment (int): The maximum number of words per segment. If the segment exceeds this, + it will be split at previous or next sentence-ending punctuation. Returns: - - List[AsrSegment]: A list of segments where each segment groups words spoken by the same speaker. + List[AsrSegment]: A list of segments where each segment groups words spoken by the same speaker. """ if not word_list: return [] - # Initialize variables to track current speaker and segment current_speaker = None current_segment = None final_segments: list[AsrSegment] = [] + word_count = 0 for word_info in word_list: - # Default speaker assignment to the last known speaker if missing speaker = word_info.speaker or current_speaker - # Check for speaker change or start of a new segment + # Handle speaker change if speaker != current_speaker: if current_segment: final_segments.append(current_segment) - - # Start a new segment - # TODO: Also get the original confidence measurements from segments and update it. 
- current_segment = AsrSegment( - time_interval=TimeInterval( - start=word_info.time_interval.start, end=word_info.time_interval.end - ), - text=word_info.word, - speaker=speaker, - words=[word_info], - confidence=None, - no_speech_confidence=None, - ) + current_segment = create_new_segment(word_info, speaker) current_speaker = speaker + word_count = 1 else: if current_segment: - # Update the current segment for continuous speaker - current_segment.time_interval.end = word_info.time_interval.end - current_segment.text += f"{word_info.word}" # Add space between words - current_segment.words.append(word_info) + # Handle word count and punctuation splitting + current_segment, word_count = split_segment_on_length_punctuation( + current_segment, + word_info, + word_count, + max_words_per_segment, + final_segments, + ) - # Append the last segment if it exists - if current_segment: + # Add the final segment if it exists + if current_segment and current_segment.words: final_segments.append(current_segment) - # Optional merging of consecutive segments by the same speaker - if merge: - final_segments = merge_consecutive_speaker_segments(final_segments) - return final_segments +def add_segment_variables( + segments: list[AsrSegment], transcription: WhisperOutput +) -> list[AsrSegment]: + """Adds confidence and no_speech_confidence variables to each segment. + + Args: + segments (List[AsrSegment]): A list of segments to which the confidence values will be added. + transcription (WhisperOutput): The transcription data to help determine segment confidence. + + Returns: + List[AsrSegment]: Segments with confidence and no_speech_confidence added. + """ + for segment in segments: + confidence, no_speech_confidence = determine_major_segment_confidence( + segment, transcription + ) + segment.confidence = confidence + segment.no_speech_confidence = no_speech_confidence + return segments + + +def split_segment_on_length_punctuation( + current_segment: AsrSegment, + word_info: AsrWord, + word_count: int, + max_words_per_segment: int, + final_segments: list[AsrSegment], +) -> tuple[AsrSegment, int]: + """Splits segments based on length and sentence-ending punctuation. + + Args: + current_segment (AsrSegment): The current speaker segment being processed. + word_info (AsrWord): Word information containing timing and text. + word_count (int): The current word count in the segment. + max_words_per_segment (int): Maximum number of words allowed in a segment before splitting. + final_segments (List[AsrSegment]): List of segments to which the completed segment will be added. + + Returns: + Tuple[AsrSegment, int]: The updated segment and word count. 
+ """ + # Check if word count exceeds the limit and if punctuation exists to split + if word_count >= max_words_per_segment and any( + p in word_info.word for p in sentence_ending_punctuations + ): + # update current segment and then append it + current_segment.time_interval.end = word_info.time_interval.end + current_segment.text += f"{word_info.word}" + current_segment.words.append(word_info) + final_segments.append(current_segment) + current_segment = create_new_segment( + word_info, current_segment.speaker, is_empty=True + ) + word_count = 0 # Reset word count + + else: + # Append word to the current segment + current_segment.time_interval.end = word_info.time_interval.end + current_segment.text += f"{word_info.word}" + current_segment.words.append(word_info) + word_count += 1 + + # If sentence-ending punctuation is found, finalize the segment + # if any(p in word_info.word for p in sentence_ending_punctuations): + # final_segments.append(current_segment) + # current_segment = create_new_segment( + # word_info, current_segment.speaker, is_empty=True + # ) + # word_count = 0 # Reset word count after punctuation + + return current_segment, word_count + + +def determine_major_segment_confidence( + segment: AsrSegment, transcription: WhisperOutput +) -> tuple[float | None, float | None]: + """Determines the confidence and no_speech_confidence based on the major segment (which contributes the most time or words). + + Args: + segment (AsrSegment): New ASR segment. + transcription (WhisperOutput): Original transcription containing segments with confidence. + + Returns: + tuple[Optional[float], Optional[float]]: Confidence and no_speech_confidence from the major segment. + """ + + def find_closest_segment(word_start: float, word_end: float) -> AsrSegment | None: + """Finds the closest segment in the transcription for the given word start and end times.""" + closest_segment = min( + transcription["segments"], + key=lambda segment: abs(segment.time_interval.start - word_start) + + abs(segment.time_interval.end - word_end), + default=None, + ) + return closest_segment + + def update_segment_contribution( + contributions: dict, segment: AsrSegment, word_duration: float + ) -> None: + """Updates the contribution data for the given segment.""" + segment_id = id(segment) + if segment_id not in contributions: + contributions[segment_id] = { + "segment": segment, + "contribution_time": 0.0, + "word_count": 0, + } + contributions[segment_id]["contribution_time"] += word_duration + contributions[segment_id]["word_count"] += 1 + + segment_contributions: defaultdict = defaultdict( + lambda: {"segment": None, "contribution_time": 0.0, "word_count": 0} + ) + + for word in segment.words: + word_start, word_end = word.time_interval.start, word.time_interval.end + word_duration = word_end - word_start + + closest_segment = find_closest_segment(word_start, word_end) + + if closest_segment: + update_segment_contribution( + segment_contributions, closest_segment, word_duration + ) + + if not segment_contributions: + return None, None + + # Determine the segment with the highest word count or contribution time + major_segment_data = max( + segment_contributions.values(), + key=lambda data: data[ + "word_count" + ], # Change this to 'contribution_time' if needed + ) + + major_segment = major_segment_data["segment"] + return major_segment.confidence, major_segment.no_speech_confidence + + def merge_consecutive_speaker_segments( segments: list[AsrSegment], ) -> list[AsrSegment]: """Merges consecutive segments that 
have the same speaker into a single segment. - Parameters: - - segments (List[AsrSegment]): The initial list of segments. + Args: + segments (List[AsrSegment]): The initial list of segments. Returns: - - List[AsrSegment]: A new list of merged segments. + merged_segments (List[AsrSegment]): A new list of merged segments. """ if not segments: return [] @@ -371,16 +526,19 @@ def merge_consecutive_speaker_segments( # Full Method def asr_postprocessing_for_diarization( - diarized_output: SpeakerDiarizationOutput, transcription: WhisperOutput + diarized_output: SpeakerDiarizationOutput, + transcription: WhisperOutput, + merge: bool = False, ) -> WhisperOutput: - """Perform diarized transcription based on individual deployments. + """Perform diarized transcription by combining outputs from individual deployments. - Parameters: - - diarized_output (SpeakerDiarizationOutput): Contains speaker diarization segments. - - transcription (WhisperOutput): Transcription data with segments, text, and language_info. + Args: + diarized_output (SpeakerDiarizationOutput): Contains speaker diarization segments. + transcription (WhisperOutput): Transcription data with segments, text, and language_info. + merge (bool): Whether to merge the same speaker segments in the end. Returns: - - transcription (WhisperOutput): Updated transcription with diarized information per segment/word. + transcription (WhisperOutput): Updated transcription with speaker information per segment/word. """ # 1. Assign speaker labels to each segment and each word in WhisperOutput based on SpeakerDiarizationOutput. @@ -389,12 +547,26 @@ def asr_postprocessing_for_diarization( diarized_output, transcription ) # 2. Aligns the speakers with the punctuations: + word_speaker_mapping = align_with_punctuation(speaker_labelled_transcription) # 3. Create ASR segments by combining the AsrWord with speaker information - # and optionally combine segments based on speaker info. - updated_transcription = create_speaker_segments(word_speaker_mapping) - transcription["segments"] = updated_transcription + + # a. Create speaker segments from new word_speaker_mapping + # b. Limits its length (default 50 words) + + # a & b + segments = create_speaker_segments(word_speaker_mapping) + + # c. Assign new confidence and no_speech_confidence to new segments + + segments = add_segment_variables(segments, transcription) + + # Optional: Merge consecutive speaker segments + if merge: + segments = merge_consecutive_speaker_segments(segments) + + transcription["segments"] = segments return transcription @@ -407,11 +579,11 @@ def combine_homogeneous_speaker_segs( ) -> SpeakerDiarizationOutput: """Combines segments with the same speaker into homogeneous speaker segments, ensuring no overlapping times. - Parameters: - - diarized_output (SpeakerDiarizationOutput): Input with segments that may have overlapping times. + Args: + diarized_output (SpeakerDiarizationOutput): Input with segments that may have overlapping times. Returns: - - SpeakerDiarizationOutput: Output with combined homogeneous speaker segments. + SpeakerDiarizationOutput: Output with combined homogeneous speaker segments. 
""" combined_segments: list = [] current_speaker = None diff --git a/aana/tests/deployments/test_pyannote_speaker_diarization_deployment.py b/aana/tests/deployments/test_pyannote_speaker_diarization_deployment.py index 97a0391f..cec00b70 100644 --- a/aana/tests/deployments/test_pyannote_speaker_diarization_deployment.py +++ b/aana/tests/deployments/test_pyannote_speaker_diarization_deployment.py @@ -1,7 +1,7 @@ # ruff: noqa: S101 from importlib import resources from pathlib import Path -import json + import pytest from aana.core.models.audio import Audio @@ -55,7 +55,5 @@ async def test_speaker_diarization(self, setup_deployment, audio_file): output = await handle.diarize(audio=audio) output = pydantic_to_dict(output) - # save or print this, evenif it is different. - # with open(expected_output_path, "w") as json_file: - # json.dump(output, json_file, indent=4) + verify_deployment_results(expected_output_path, output) diff --git a/aana/tests/files/expected/whisper/whisper_medium_sd_sample.wav_diar.json b/aana/tests/files/expected/whisper/whisper_medium_sd_sample.wav_diar.json index baf3775d..0cdab306 100644 --- a/aana/tests/files/expected/whisper/whisper_medium_sd_sample.wav_diar.json +++ b/aana/tests/files/expected/whisper/whisper_medium_sd_sample.wav_diar.json @@ -6,8 +6,8 @@ "start": 6.38, "end": 7.84 }, - "confidence": null, - "no_speech_confidence": null, + "confidence": 0.8329984157521475, + "no_speech_confidence": 0.012033582665026188, "words": [ { "word": " Hello.", @@ -36,8 +36,8 @@ "start": 8.3, "end": 9.68 }, - "confidence": null, - "no_speech_confidence": null, + "confidence": 0.8329984157521475, + "no_speech_confidence": 0.012033582665026188, "words": [ { "word": " Oh,", @@ -120,8 +120,8 @@ "start": 9.98, "end": 10.64 }, - "confidence": null, - "no_speech_confidence": null, + "confidence": 0.8329984157521475, + "no_speech_confidence": 0.012033582665026188, "words": [ { "word": " Neither", @@ -159,8 +159,8 @@ "start": 10.64, "end": 13.94 }, - "confidence": null, - "no_speech_confidence": null, + "confidence": 0.8329984157521475, + "no_speech_confidence": 0.012033582665026188, "words": [ { "word": " Okay,", @@ -306,8 +306,8 @@ "start": 14.36, "end": 17.42 }, - "confidence": null, - "no_speech_confidence": null, + "confidence": 0.8329984157521475, + "no_speech_confidence": 0.012033582665026188, "words": [ { "word": " And", @@ -390,8 +390,8 @@ "start": 18.04, "end": 21.34 }, - "confidence": null, - "no_speech_confidence": null, + "confidence": 0.8329984157521475, + "no_speech_confidence": 0.012033582665026188, "words": [ { "word": " Oh,", @@ -510,8 +510,8 @@ "start": 21.74, "end": 28.32 }, - "confidence": null, - "no_speech_confidence": null, + "confidence": 0.8329984157521475, + "no_speech_confidence": 0.012033582665026188, "words": [ { "word": " Well,", @@ -729,8 +729,8 @@ "start": 28.38, "end": 29.82 }, - "confidence": null, - "no_speech_confidence": null, + "confidence": 0.8329984157521475, + "no_speech_confidence": 0.012033582665026188, "words": [ { "word": " Oh,", diff --git a/aana/tests/units/test_speaker.py b/aana/tests/units/test_speaker.py index eb72a3d4..1068074c 100644 --- a/aana/tests/units/test_speaker.py +++ b/aana/tests/units/test_speaker.py @@ -11,6 +11,7 @@ AsrTranscription, AsrTranscriptionInfo, ) +from aana.core.models.base import pydantic_to_dict from aana.core.models.speaker import SpeakerDiarizationSegment from aana.processors.speaker import ( SpeakerDiarizationOutput, @@ -23,7 +24,7 @@ @pytest.mark.parametrize("audio_file", ["sd_sample.wav"]) def 
test_asr_diarization_post_process(audio_file: Literal["sd_sample.wav"]): """Test that the ASR output can be processed to generate diarized transcription.""" - # load precomputed asr and diarization outputs + # load precomputed ASR and Diarization outputs asr_path = ( resources.files("aana.tests.files.expected.whisper") / f"whisper_medium_{audio_file}.json" @@ -36,8 +37,8 @@ def test_asr_diarization_post_process(audio_file: Literal["sd_sample.wav"]): resources.files("aana.tests.files.expected.whisper") / f"whisper_medium_{audio_file}_diar.json" ) - # convert to WhisperOutput and SpeakerDiarizationOutput + # convert to WhisperOutput and SpeakerDiarizationOutput with Path.open(asr_path, "r") as json_file: asr_op = json.load(json_file) @@ -63,5 +64,6 @@ def test_asr_diarization_post_process(audio_file: Literal["sd_sample.wav"]): diar_output, asr_output, ) + processed_transcription = pydantic_to_dict(processed_transcription) verify_deployment_results(expected_results_path, processed_transcription) From a5944e73a86384874aac07ebb67daa680b146528 Mon Sep 17 00:00:00 2001 From: jiltseb Date: Thu, 19 Sep 2024 11:39:36 +0000 Subject: [PATCH 05/18] remove Union dependency and redundant comments --- aana/processors/speaker.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/aana/processors/speaker.py b/aana/processors/speaker.py index b90ddebd..8e8af056 100644 --- a/aana/processors/speaker.py +++ b/aana/processors/speaker.py @@ -1,5 +1,5 @@ from collections import defaultdict -from typing import TypedDict, Union +from typing import TypedDict from aana.core.models.asr import ( AsrSegment, @@ -108,8 +108,6 @@ def get_speaker_for_interval( overlap_duration = max(0.0, overlap_end - overlap_start) if overlap_duration > 0 or fill_nearest: - # Calculate union duration for potential future use - # union_duration = max(end_time, interval_end) - min(start_time, interval_start) distance = float( min(abs(start_time - interval_end), abs(end_time - interval_start)) ) From 9f8eb8cbba404ffeecee24b02c0e029dd863027e Mon Sep 17 00:00:00 2001 From: jiltseb Date: Fri, 20 Sep 2024 07:49:14 +0000 Subject: [PATCH 06/18] added description in docs --- docs/pages/model_hub/asr.md | 9 ++++++++- docs/reference/processors.md | 3 ++- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/docs/pages/model_hub/asr.md b/docs/pages/model_hub/asr.md index f03534f5..3d811b7c 100644 --- a/docs/pages/model_hub/asr.md +++ b/docs/pages/model_hub/asr.md @@ -49,4 +49,11 @@ Here are some other possible configurations for the Whisper deployment: compute_type=WhisperComputeType.FLOAT32, ).model_dump(mode="json"), ) - ``` \ No newline at end of file + ``` + +### Diarized ASR + +Diarized transcription can be generated by using [WhisperDeployment](./../../reference/deployments.md#aana.deployments.WhisperDeployment) and [PyannoteSpeakerDiarizationDeployment](./../../reference/deployments.md#aana.deployments.PyannoteSpeakerDiarizationDeployment) by combining the timelines using post processing with [speaker information](./../../reference/processors.md). + +You can simply define the model deployments and the endpoint to transcribe the video with diarization. Register these deployments and start the application. An example notebook on diarized transcription is available [here](./../../../notebooks/diarized_transcription_example.ipynb). 
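For illustration, a minimal sketch of that combination step as it stands at this point in the series, assuming `transcription` is a `WhisperOutput` produced with word timestamps enabled and `diarized_output` is a `SpeakerDiarizationOutput` from the pyannote deployment:

```python
# Sketch: merge the two deployment outputs into a speaker-labelled transcript.
from aana.processors.speaker import asr_postprocessing_for_diarization

diarized_transcription = asr_postprocessing_for_diarization(diarized_output, transcription)

# Each AsrSegment now carries a speaker label alongside its text and timing.
for segment in diarized_transcription["segments"]:
    print(segment.speaker, segment.text)
```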
+ diff --git a/docs/reference/processors.md b/docs/reference/processors.md index 1b463ef8..a656a6bb 100644 --- a/docs/reference/processors.md +++ b/docs/reference/processors.md @@ -2,4 +2,5 @@ ::: aana.processors.remote ::: aana.processors.video -::: aana.processors.batch \ No newline at end of file +::: aana.processors.batch +::: aana.processors.speaker \ No newline at end of file From 835501508cb36c3cd8b11bfa09d78244b176b123 Mon Sep 17 00:00:00 2001 From: jiltseb Date: Fri, 20 Sep 2024 08:09:28 +0000 Subject: [PATCH 07/18] fix typo in docs --- docs/pages/model_hub/asr.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/pages/model_hub/asr.md b/docs/pages/model_hub/asr.md index 3d811b7c..3d658aeb 100644 --- a/docs/pages/model_hub/asr.md +++ b/docs/pages/model_hub/asr.md @@ -53,7 +53,7 @@ Here are some other possible configurations for the Whisper deployment: ### Diarized ASR -Diarized transcription can be generated by using [WhisperDeployment](./../../reference/deployments.md#aana.deployments.WhisperDeployment) and [PyannoteSpeakerDiarizationDeployment](./../../reference/deployments.md#aana.deployments.PyannoteSpeakerDiarizationDeployment) by combining the timelines using post processing with [speaker information](./../../reference/processors.md). +Diarized transcription can be generated by using [WhisperDeployment](./../../reference/deployments.md#aana.deployments.WhisperDeployment) and [PyannoteSpeakerDiarizationDeployment](./../../reference/deployments.md#aana.deployments.PyannoteSpeakerDiarizationDeployment) and combining the timelines using post processing with [speaker information](./../../reference/processors.md). -You can simply define the model deployments and the endpoint to transcribe the video with diarization. Register these deployments and start the application. An example notebook on diarized transcription is available [here](./../../../notebooks/diarized_transcription_example.ipynb). +You can simply define the model deployments and the endpoint to transcribe the video with diarization. Register these deployments and start the application. An example notebook on diarized transcription is available at `notebooks/diarized_transcription_example.ipynb`. 
From f3f0e1d8673b4151b5bbac05e321a20a51401176 Mon Sep 17 00:00:00 2001 From: jiltseb Date: Mon, 23 Sep 2024 12:36:16 +0000 Subject: [PATCH 08/18] Add default word speaker to None --- aana/core/models/asr.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aana/core/models/asr.py b/aana/core/models/asr.py index 0190f37f..36ed4513 100644 --- a/aana/core/models/asr.py +++ b/aana/core/models/asr.py @@ -37,7 +37,7 @@ class AsrWord(BaseModel): """ word: str = Field(description="The word text") - speaker: str | None = Field(description="Speaker label for the word") + speaker: str | None = Field(None, description="Speaker label for the word") time_interval: TimeInterval = Field(description="Time interval of the word") alignment_confidence: float = Field( ge=0.0, le=1.0, description="Alignment confidence of the word" From 2ae3e83026850aa49bf265510ff80e1be2831334 Mon Sep 17 00:00:00 2001 From: jiltseb Date: Mon, 23 Sep 2024 12:38:45 +0000 Subject: [PATCH 09/18] added GatedRepoError, uses segments for combining speaker segments --- ...pyannote_speaker_diarization_deployment.py | 24 ++++++++----------- 1 file changed, 10 insertions(+), 14 deletions(-) diff --git a/aana/deployments/pyannote_speaker_diarization_deployment.py b/aana/deployments/pyannote_speaker_diarization_deployment.py index 0fbdc9e7..17efc322 100644 --- a/aana/deployments/pyannote_speaker_diarization_deployment.py +++ b/aana/deployments/pyannote_speaker_diarization_deployment.py @@ -1,7 +1,7 @@ -import logging from typing import Any, TypedDict import torch +from huggingface_hub.utils import GatedRepoError from pyannote.audio import Pipeline from pyannote.core import Annotation from pydantic import BaseModel, ConfigDict, Field @@ -16,7 +16,7 @@ from aana.core.models.time import TimeInterval from aana.deployments.base_deployment import BaseDeployment from aana.exceptions.runtime import InferenceException -from aana.processors.speaker import combine_homogeneous_speaker_segs +from aana.processors.speaker import combine_homogeneous_speaker_diarization_segments class SpeakerDiarizationOutput(TypedDict): @@ -75,17 +75,11 @@ async def apply_config(self, config: dict[str, Any]): # load model using pyannote Pipeline self.diarize_model = Pipeline.from_pretrained(self.model_id) - # Check if model, log error if None. - if self.diarize_model is None: - logging.error( - f"Accept user agreements at huggingface for the model: {self.model_id}" - ) - - else: + if self.diarize_model: self.diarize_model.to(torch.device(self.device)) except Exception as e: - raise InferenceException(self.model_id) from e + raise GatedRepoError from e async def __inference( self, audio: Audio, params: PyannoteSpeakerDiarizationParams @@ -148,8 +142,10 @@ async def diarize( ) ) - # Perform post-processing to combine homogeneous speaker segments. - processed_speaker_diarization_segments = combine_homogeneous_speaker_segs( - SpeakerDiarizationOutput(segments=speaker_diarization_segments) + # Combine homogeneous speaker segments. 
+ processed_speaker_diarization_segments = ( + combine_homogeneous_speaker_diarization_segments( + speaker_diarization_segments + ) ) - return processed_speaker_diarization_segments + return SpeakerDiarizationOutput(segments=processed_speaker_diarization_segments) From a0776d7713bf2ddf40a8f58db13e607123618091 Mon Sep 17 00:00:00 2001 From: jiltseb Date: Mon, 23 Sep 2024 12:42:01 +0000 Subject: [PATCH 10/18] code snippet to illustrate combining ASR and diarization outputs, link to speaker_recognition model hub --- docs/pages/model_hub/asr.md | 61 +++++++++++++++++++++++++++++++++++-- 1 file changed, 59 insertions(+), 2 deletions(-) diff --git a/docs/pages/model_hub/asr.md b/docs/pages/model_hub/asr.md index 3d658aeb..a8d1f513 100644 --- a/docs/pages/model_hub/asr.md +++ b/docs/pages/model_hub/asr.md @@ -53,7 +53,64 @@ Here are some other possible configurations for the Whisper deployment: ### Diarized ASR -Diarized transcription can be generated by using [WhisperDeployment](./../../reference/deployments.md#aana.deployments.WhisperDeployment) and [PyannoteSpeakerDiarizationDeployment](./../../reference/deployments.md#aana.deployments.PyannoteSpeakerDiarizationDeployment) and combining the timelines using post processing with [speaker information](./../../reference/processors.md). +Diarized transcription can be generated by using [WhisperDeployment](./../../reference/deployments.md#aana.deployments.WhisperDeployment) and [PyannoteSpeakerDiarizationDeployment](./../../reference/deployments.md#aana.deployments.PyannoteSpeakerDiarizationDeployment) and combining the timelines using post processing with [ASRPostProcessingForDiarization](./../../reference/processors.md). -You can simply define the model deployments and the endpoint to transcribe the video with diarization. Register these deployments and start the application. An example notebook on diarized transcription is available at `notebooks/diarized_transcription_example.ipynb`. +Example configuration for the PyannoteSpeakerDiarization model is available at [Speaker Diarization](./speaker_recognition.md) model hub. + +You can simply define the model deployments and the endpoint to transcribe the video with diarization. 
Below code snippet shows the custom endpoint class `TranscribeVideoWithDiarEndpoint` to combine the outputs from ASR and diarization deployments: + + ```python + from aana.api.api_generation import Endpoint + from aana.core.models.speaker import PyannoteSpeakerDiarizationParams + from aana.core.models.video import VideoInput + from aana.core.models.whisper import WhisperParams + from aana.deployments.whisper_deployment import WhisperOutput + + from aana.deployments.aana_deployment_handle import AanaDeploymentHandle + + from aana.integrations.external.yt_dlp import download_video + from aana.processors.remote import run_remote + from aana.processors.speaker import ASRPostProcessingForDiarization + from aana.processors.video import extract_audio + + class TranscribeVideoWithDiarEndpoint(Endpoint): + """Transcribe video with diarization endpoint.""" + + async def initialize(self): + """Initialize the endpoint.""" + self.asr_handle = await AanaDeploymentHandle.create("asr_deployment") + self.diar_handle = await AanaDeploymentHandle.create("diarization_deployment") + await super().initialize() + + async def run( + self, + video: VideoInput, + whisper_params: WhisperParams, + diar_params: PyannoteSpeakerDiarizationParams, + ) -> WhisperOutput: + """Transcribe video with diarization.""" + video_obj = await run_remote(download_video)(video_input=video) + audio = extract_audio(video=video_obj) + + # diarized transcript requires word_timestamps from ASR + whisper_params.word_timestamps = True + transcription = await self.asr_handle.transcribe( + audio=audio, params=whisper_params + ) + diarized_output = await self.diar_handle.diarize( + audio=audio, params=diar_params + ) + post_processor = ASRPostProcessingForDiarization( + diarized_segments=diarized_output["segments"], + transcription_segments=transcription["segments"], + ) + updated_segments = post_processor.process() + output_segments = [ + s.model_dump(include=["text", "time_interval", "speaker"]) + for s in updated_segments + ] + + return {"segments": output_segments} + ``` +An example notebook on diarized transcription is available at `notebooks/diarized_transcription_example.ipynb`. 
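As a usage note, a sketch of wiring the endpoint class above into an app, mirroring the example notebook; `asr_deployment` and `diarization_deployment` stand for the deployment instances configured as described above and in the speaker diarization docs:

```python
# Sketch: register the deployments and the endpoint, then start serving.
from aana.sdk import AanaSDK

aana_app = AanaSDK(name="transcribe_video_app")
aana_app.register_deployment(name="asr_deployment", instance=asr_deployment)
aana_app.register_deployment(name="diarization_deployment", instance=diarization_deployment)
aana_app.register_endpoint(
    name="transcribe_video",
    path="/video/transcribe",
    summary="Transcribe a video",
    endpoint_cls=TranscribeVideoWithDiarEndpoint,
)
aana_app.connect(host="127.0.0.1", port=8000, show_logs=False)
aana_app.migrate()
aana_app.deploy()
```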
From 1e14419788edf3e1e29dc1c49aa7b2dc84f5e652 Mon Sep 17 00:00:00 2001 From: jiltseb Date: Mon, 23 Sep 2024 12:44:25 +0000 Subject: [PATCH 11/18] removes usage of WhisperOutput and SpeakerDiarizationOutput classes, update post processor --- aana/tests/units/test_speaker.py | 43 ++++++++++---------------------- 1 file changed, 13 insertions(+), 30 deletions(-) diff --git a/aana/tests/units/test_speaker.py b/aana/tests/units/test_speaker.py index 1068074c..bd84f684 100644 --- a/aana/tests/units/test_speaker.py +++ b/aana/tests/units/test_speaker.py @@ -6,18 +6,9 @@ import pytest -from aana.core.models.asr import ( - AsrSegment, - AsrTranscription, - AsrTranscriptionInfo, -) -from aana.core.models.base import pydantic_to_dict +from aana.core.models.asr import AsrSegment from aana.core.models.speaker import SpeakerDiarizationSegment -from aana.processors.speaker import ( - SpeakerDiarizationOutput, - WhisperOutput, - asr_postprocessing_for_diarization, -) +from aana.processors.speaker import ASRPostProcessingForDiarization from aana.tests.utils import verify_deployment_results @@ -42,28 +33,20 @@ def test_asr_diarization_post_process(audio_file: Literal["sd_sample.wav"]): with Path.open(asr_path, "r") as json_file: asr_op = json.load(json_file) - asr_output = WhisperOutput( - segments=[AsrSegment.model_validate(segment) for segment in asr_op["segments"]], - transcription_info=AsrTranscriptionInfo.model_validate( - asr_op["transcription_info"] - ), - transcription=AsrTranscription.model_validate(asr_op["transcription"]), - ) + asr_segments = [ + AsrSegment.model_validate(segment) for segment in asr_op["segments"] + ] with Path.open(diar_path, "r") as json_file: diar_op = json.load(json_file) - diar_output = SpeakerDiarizationOutput( - segments=[ - SpeakerDiarizationSegment.model_validate(segment) - for segment in diar_op["segments"] - ], - ) - - processed_transcription = asr_postprocessing_for_diarization( - diar_output, - asr_output, + diarized_segments = [ + SpeakerDiarizationSegment.model_validate(segment) + for segment in diar_op["segments"] + ] + post_processor = ASRPostProcessingForDiarization( + diarized_segments=diarized_segments, transcription_segments=asr_segments ) - processed_transcription = pydantic_to_dict(processed_transcription) + asr_op["segments"] = post_processor.process() - verify_deployment_results(expected_results_path, processed_transcription) + verify_deployment_results(expected_results_path, asr_op) From 0ffbc19f9edd1b1e27429dec6a14e248fcd84b6c Mon Sep 17 00:00:00 2001 From: jiltseb Date: Mon, 23 Sep 2024 12:45:42 +0000 Subject: [PATCH 12/18] added linkt to asr model hub for diarized ASR implementation details --- docs/pages/model_hub/speaker_recognition.md | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/docs/pages/model_hub/speaker_recognition.md b/docs/pages/model_hub/speaker_recognition.md index 0d34d444..5aa053e4 100644 --- a/docs/pages/model_hub/speaker_recognition.md +++ b/docs/pages/model_hub/speaker_recognition.md @@ -35,7 +35,7 @@ The PyAnnote speaker diarization models are gated, requiring special access. To To get your Hugging Face access token, visit the [Hugging Face Settings - Tokens](https://huggingface.co/settings/tokens). -## Example Configurations +### Example Configurations As an example, let's see how to configure the Pyannote Speaker Diarization deployment for the [Speaker Diarization-3.1 model](https://huggingface.co/pyannote/speaker-diarization-3.1). 
@@ -53,4 +53,9 @@ As an example, let's see how to configure the Pyannote Speaker Diarization deplo sample_rate=16000, ).model_dump(mode="json"), ) - ``` \ No newline at end of file + ``` + +### Diarized ASR + +Speaker Diarization output can be combined with ASR to generate transcription with speaker information. Further details and code snippet are available in [ASR model hub](./asr.md). + From 42d76198372f9760ccd27fed34670a35f147f855 Mon Sep 17 00:00:00 2001 From: jiltseb Date: Mon, 23 Sep 2024 12:50:05 +0000 Subject: [PATCH 13/18] added post processing class, minor changes --- .../diarized_transcription_example.ipynb | 43 ++++++++++--------- 1 file changed, 22 insertions(+), 21 deletions(-) diff --git a/notebooks/diarized_transcription_example.ipynb b/notebooks/diarized_transcription_example.ipynb index e5d5a5be..5d0114eb 100644 --- a/notebooks/diarized_transcription_example.ipynb +++ b/notebooks/diarized_transcription_example.ipynb @@ -18,13 +18,13 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import os\n", "\n", - "os.environ[\"HF_TOKEN\"] = \"YOUR_HF_TOKEN_GOES_HERE\"\n", + "os.environ[\"HF_TOKEN\"] = \"\"\n", "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"" ] }, @@ -37,13 +37,15 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ + "2024-09-23 11:22:15,770\tWARNING services.py:2017 -- WARNING: The object store is using /tmp instead of /dev/shm because /dev/shm has only 67108864 bytes available. This will harm performance! You may be able to free up space by deleting files in /dev/shm. If you are inside a Docker container, you can increase /dev/shm size by passing '--shm-size=10.24gb' to 'docker run' (or add it to the run_options list in a Ray cluster config). Make sure to set this to more than 30% of available RAM.\n", + "2024-09-23 11:22:16,932\tINFO worker.py:1774 -- Started a local Ray instance. 
View the dashboard at \u001b[1m\u001b[32m127.0.0.1:8265 \u001b[39m\u001b[22m\n", "INFO [alembic.runtime.migration] Context impl SQLiteImpl.\n", "INFO [alembic.runtime.migration] Will assume non-transactional DDL.\n" ] @@ -51,7 +53,6 @@ ], "source": [ "from aana.api.api_generation import Endpoint\n", - "from aana.core.models.base import pydantic_to_dict\n", "from aana.core.models.speaker import PyannoteSpeakerDiarizationParams\n", "from aana.core.models.video import VideoInput\n", "from aana.core.models.whisper import WhisperParams\n", @@ -69,7 +70,7 @@ ")\n", "from aana.integrations.external.yt_dlp import download_video\n", "from aana.processors.remote import run_remote\n", - "from aana.processors.speaker import asr_postprocessing_for_diarization\n", + "from aana.processors.speaker import ASRPostProcessingForDiarization\n", "from aana.processors.video import extract_audio\n", "from aana.sdk import AanaSDK\n", "\n", @@ -80,7 +81,7 @@ " \"num_gpus\": 0.25\n", " }, # Remove this line if you want to run Whisper on a CPU.# Also change type to float32.\n", " user_config=WhisperConfig(\n", - " model_size=WhisperModelSize.SMALL,\n", + " model_size=WhisperModelSize.MEDIUM,\n", " compute_type=WhisperComputeType.FLOAT16,\n", " ).model_dump(mode=\"json\"),\n", ")\n", @@ -127,18 +128,17 @@ " diarized_output = await self.diar_handle.diarize(\n", " audio=audio, params=diar_params\n", " )\n", - " transcription = asr_postprocessing_for_diarization(\n", - " diarized_output, transcription\n", + " post_processor = ASRPostProcessingForDiarization(\n", + " diarized_segments=diarized_output[\"segments\"],\n", + " transcription_segments=transcription[\"segments\"],\n", " )\n", - " output = pydantic_to_dict(transcription)\n", - "\n", - " output_keys = [\"time_interval\", \"speaker\", \"text\"]\n", - " filtered_data = [\n", - " {k: v for k, v in entry.items() if k in output_keys}\n", - " for entry in output[\"segments\"]\n", + " updated_segments = post_processor.process()\n", + " output_segments = [\n", + " s.model_dump(include=[\"text\", \"time_interval\", \"speaker\"])\n", + " for s in updated_segments\n", " ]\n", "\n", - " return {\"segments\": filtered_data}\n", + " return {\"segments\": output_segments}\n", "\n", "\n", "endpoints = [\n", @@ -173,7 +173,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 3, "metadata": {}, "outputs": [ { @@ -196,7 +196,7 @@ "\n" ], "text/plain": [ - "Documentation is available at \u001b]8;id=337020;http://127.0.0.1:8000/docs\u001b\\\u001b[4;94mhttp://127.0.0.1:8000/docs\u001b[0m\u001b]8;;\u001b\\ and \u001b]8;id=418863;http://127.0.0.1:8000/redoc\u001b\\\u001b[4;94mhttp://127.0.0.1:8000/redoc\u001b[0m\u001b]8;;\u001b\\\n" + "Documentation is available at \u001b]8;id=363073;http://127.0.0.1:8000/docs\u001b\\\u001b[4;94mhttp://127.0.0.1:8000/docs\u001b[0m\u001b]8;;\u001b\\ and \u001b]8;id=935227;http://127.0.0.1:8000/redoc\u001b\\\u001b[4;94mhttp://127.0.0.1:8000/redoc\u001b[0m\u001b]8;;\u001b\\\n" ] }, "metadata": {}, @@ -216,24 +216,25 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "{'segments': [{'text': ' Hello. Hello.', 'time_interval': {'start': 6.38, 'end': 7.84}, 'speaker': 'SPEAKER_01'}, {'text': \" Oh, hello. 
I didn't know you were there.\", 'time_interval': {'start': 8.3, 'end': 9.68}, 'speaker': 'SPEAKER_02'}, {'text': ' Neither did I.', 'time_interval': {'start': 9.98, 'end': 10.64}, 'speaker': 'SPEAKER_01'}, {'text': ' Okay, I thought, you know, I heard it deep. This is Diane in New Jersey.', 'time_interval': {'start': 10.64, 'end': 13.94}, 'speaker': 'SPEAKER_02'}, {'text': \" And I'm Sheila in Texas, originally from Chicago.\", 'time_interval': {'start': 14.36, 'end': 17.42}, 'speaker': 'SPEAKER_00'}, {'text': \" Oh, I'm originally from Chicago also. I'm in New Jersey now, though.\", 'time_interval': {'start': 18.04, 'end': 21.34}, 'speaker': 'SPEAKER_02'}, {'text': ' Well, there is not much difference. At least, you know, they all call me a Yankee down here, so what can I say?', 'time_interval': {'start': 21.74, 'end': 28.32}, 'speaker': 'SPEAKER_00'}, {'text': \" Oh, I don't hear that in New Jersey now.\", 'time_interval': {'start': 28.38, 'end': 29.82}, 'speaker': 'SPEAKER_02'}]}\n" + "{'segments': [{'text': ' Hello? Hello.', 'time_interval': {'start': 6.9, 'end': 8.14}, 'speaker': 'SPEAKER_01'}, {'text': \" Oh, hello. I didn't know you were there.\", 'time_interval': {'start': 8.4, 'end': 9.9}, 'speaker': 'SPEAKER_02'}, {'text': ' Neither did I.', 'time_interval': {'start': 10.22, 'end': 10.88}, 'speaker': 'SPEAKER_01'}, {'text': ' Okay. I thought, you know, I heard a beep. This is Diane in New Jersey.', 'time_interval': {'start': 10.9, 'end': 14.16}, 'speaker': 'SPEAKER_02'}, {'text': \" And I'm Sheila in Texas, originally from Chicago.\", 'time_interval': {'start': 14.4, 'end': 17.74}, 'speaker': 'SPEAKER_00'}, {'text': \" Oh, I'm originally from Chicago also. I'm in New Jersey now, though.\", 'time_interval': {'start': 18.16, 'end': 21.48}, 'speaker': 'SPEAKER_02'}, {'text': \" Well, there isn't that much difference. 
At least, you know, they all call me a Yankee down here, so what can I say?\", 'time_interval': {'start': 21.9, 'end': 28.38}, 'speaker': 'SPEAKER_00'}, {'text': \" Oh, I don't hear that in New Jersey now.\", 'time_interval': {'start': 28.4, 'end': 29.88}, 'speaker': 'SPEAKER_02'}]}\n" ] } ], "source": [ "import json\n", - "\n", "import requests\n", "\n", + "\n", "video = {\n", - " \"path\": \"../aana/tests/files/audios/sd_sample.wav\", # Video URL, Aana SDK supports URLs (including YouTube), file paths or even raw video data\n", + " # Video URL/path, Aana SDK supports URLs (including YouTube), file paths or even raw video data\n", + " \"path\": \"../aana/tests/files/audios/sd_sample.wav\",\n", " \"media_id\": \"sd_sample\", # Media ID, so we can ask questions about the video later by using this ID\n", "}\n", "\n", From 8b62bedf598e1c2050fb96dc80d9fd6300430e77 Mon Sep 17 00:00:00 2001 From: jiltseb Date: Mon, 23 Sep 2024 13:46:46 +0000 Subject: [PATCH 14/18] code refactoring for diarized ASR post processing --- aana/processors/speaker.py | 1050 ++++++++++++++++++------------------ 1 file changed, 519 insertions(+), 531 deletions(-) diff --git a/aana/processors/speaker.py b/aana/processors/speaker.py index 8e8af056..140e0dff 100644 --- a/aana/processors/speaker.py +++ b/aana/processors/speaker.py @@ -1,595 +1,583 @@ from collections import defaultdict -from typing import TypedDict from aana.core.models.asr import ( AsrSegment, - AsrTranscription, - AsrTranscriptionInfo, AsrWord, ) from aana.core.models.speaker import SpeakerDiarizationSegment from aana.core.models.time import TimeInterval -# Utility functions for speaker-related processing in audio - - -# Redefine SpeakerDiarizationOutput and WhisperOutput to prevent circular imports -class SpeakerDiarizationOutput(TypedDict): - """The output of the Speaker Diarization model. - - Attributes: - segments (list[SpeakerDiarizationSegment]): The Speaker Diarization segments. - """ - - segments: list[SpeakerDiarizationSegment] - - -class WhisperOutput(TypedDict): - """The output of the whisper model. - - Attributes: - segments (list[AsrSegment]): The ASR segments. - transcription_info (AsrTranscriptionInfo): The ASR transcription info. - transcription (AsrTranscription): The ASR transcription. - """ - - segments: list[AsrSegment] - transcription_info: AsrTranscriptionInfo - transcription: AsrTranscription - - # Define sentence ending punctuations to split segments at sentence endings: sentence_ending_punctuations = ".?!" -def assign_word_speakers( - diarized_output: SpeakerDiarizationOutput, - transcription: WhisperOutput, - fill_nearest: bool = False, -) -> WhisperOutput: - """Assigns speaker labels to each segment and word in the transcription based on diarized output. - - Args: - diarized_output (SpeakerDiarizationOutput): Contains speaker diarization segments. - transcription (WhisperOutput): Transcription data with segments, text, and language_info. - fill_nearest (bool): If True, assigns the closest speaker even if there's no positive overlap. Default is False. - - Returns: - transcription (WhisperOutput): Transcription updated in-place with the assigned speaker labels. 
- """ - for segment in transcription["segments"]: - # Assign speaker to segment - segment.speaker = get_speaker_for_interval( - diarized_output["segments"], - segment.time_interval.start, - segment.time_interval.end, - fill_nearest, - ) - - # Assign speakers to words within the segment - if segment.words: - for word in segment.words: - word.speaker = get_speaker_for_interval( - diarized_output["segments"], - word.time_interval.start, - word.time_interval.end, - fill_nearest, - ) - - return transcription +class ASRPostProcessingForDiarization: + """Class to handle post-processing for diarized ASR output by combining diarization and transcription segments. + The post-processing involves assigning speaker labels to transcription segments and words, aligning speakers + with punctuation, optionally merging homogeneous speaker segments, and reassigning confidence information to the segments. -def get_speaker_for_interval( - sd_segments: list[SpeakerDiarizationSegment], - start_time: float, - end_time: float, - fill_nearest: bool, -) -> str | None: - """Determines the speaker for a given time interval based on diarized segments. - - Args: - sd_segments (list[SpeakerDiarizationSegment]): List of speaker diarization segments. - start_time (float): Start time of the interval. - end_time (float): End time of the interval. - fill_nearest (bool): If True, selects the closest speaker even with no overlap. - - Returns: - str | None: The identified speaker label, or None if no speaker is found. + Attributes: + diarized_segments (list[SpeakerDiarizationSegment]): Contains speaker diarization segments. + transcription_segments (list[AsrSegment]): Transcription segments. + merge (bool): Whether to merge the same speaker segments in the final output. """ - overlaps = [] - - for sd_segment in sd_segments: - interval_start = sd_segment.time_interval.start - interval_end = sd_segment.time_interval.end - - # Calculate overlap duration - overlap_start = max(start_time, interval_start) - overlap_end = min(end_time, interval_end) - overlap_duration = max(0.0, overlap_end - overlap_start) - - if overlap_duration > 0 or fill_nearest: - distance = float( - min(abs(start_time - interval_end), abs(end_time - interval_start)) - ) - overlaps.append( - { - "speaker": sd_segment.speaker, - "overlap_duration": overlap_duration, - "distance": distance, - } - ) - - if not overlaps: - return None - else: - # Select the speaker with the maximum overlap duration (or minimal distance) - best_match = max( - overlaps, - key=lambda x: (x["overlap_duration"], -x["distance"]) - if not fill_nearest - else (-x["distance"]), - ) - return best_match["speaker"] - - -def get_first_word_idx_of_sentence( - word_idx: int, - word_list: list[str], - speaker_list: list[str | None], - max_words: int, -) -> int: - """Get the index of the first word of the sentence in the given range.""" - left_idx = word_idx - while ( - left_idx > 0 - and word_idx - left_idx < max_words - and speaker_list[left_idx - 1] == speaker_list[left_idx] - and word_list[left_idx - 1][-1] not in sentence_ending_punctuations - ): - left_idx -= 1 - - return ( - left_idx - if left_idx == 0 or word_list[left_idx - 1][-1] in sentence_ending_punctuations - else -1 - ) - - -def get_last_word_idx_of_sentence( - word_idx: int, word_list: list[str], max_words: int -) -> int: - """Get the index of the last word of the sentence in the given range.""" - right_idx = word_idx - while ( - right_idx < len(word_list) - and right_idx - word_idx < max_words - and word_list[right_idx][-1] not in 
sentence_ending_punctuations + def __init__( + self, + diarized_segments: list[SpeakerDiarizationSegment], + transcription_segments: list[AsrSegment], + merge: bool = False, ): - right_idx += 1 - - return ( - right_idx - if right_idx == len(word_list) - 1 - or word_list[right_idx][-1] in sentence_ending_punctuations - else -1 - ) - - -def find_nearest_speaker( - index: int, word_speaker_mapping: list[AsrWord], reverse: bool = False -) -> str | None: - """Find the nearest speaker label in the word_speaker_mapping either forward or backward. - - Args: - index (int): The index to start searching from. - word_speaker_mapping (list[AsrWord]): List of word-speaker mappings. - reverse (bool): Search backwards if True; forwards if False. Default is False. - - Returns: - str | None: The nearest speaker found or None if not found. - """ - step = -1 if reverse else 1 - for i in range(index, len(word_speaker_mapping) if not reverse else -1, step): - if word_speaker_mapping[i].speaker: - return word_speaker_mapping[i].speaker - return None - - -def align_with_punctuation( - transcription: WhisperOutput, max_words_in_sentence: int = 50 -) -> list[AsrWord]: - """Aligns speaker labels with sentence boundaries defined by punctuation. + """Initializes the ASRPostProcessingForDiarization class with diarization and transcription segments. + + Args: + diarized_segments (list[SpeakerDiarizationSegment]): Contains speaker diarization segments. + transcription_segments (list[AsrSegment]): Transcription segments. + merge (bool): If True, merges consecutive speaker segments in the final output. Defaults to False. + """ + self.diarized_segments = diarized_segments + self.transcription_segments = transcription_segments + self.merge = merge + + # Check if inputs are valid: + for segment in self.transcription_segments: + if segment.text and not segment.words: + raise ValueError("Word-level timestamps are required for diarized ASR.") # noqa: TRY003 + + def process(self) -> list[AsrSegment]: + """Executes the post-processing pipeline that combines diarization and transcription segments. + + This method performs the following steps: + 1. Assign speaker labels to each segment and word in the transcription based on the diarization output. + 2. Align speakers with punctuation. + 3. Create new transcription segments by combining the speaker-labeled words. + 4. Optionally, merge consecutive speaker segments. + 5. Add confidence and no_speech_confidence to the new segments. + + Returns: + list[AsrSegment]: Updated transcription segments with speaker information per segment and word. + """ + # intro: validation checks + if not self.transcription_segments or not self.diarized_segments: + return self.transcription_segments + + # 1. Assign speaker labels to each segment and word + speaker_labelled_transcription = self._assign_word_speakers( + self.diarized_segments, self.transcription_segments + ) - Args: - transcription (WhisperOutput): transcription with speaker information. - max_words_in_sentence (int): Maximum number of words allowed in a sentence. + # 2. Align speakers with punctuation + word_speaker_mapping = self._align_with_punctuation( + speaker_labelled_transcription + ) - Returns: - word_speaker_mapping: (list[AsrWord]): Realigned word-speaker mappings. 
- """ - new_segments = [segment.words for segment in transcription["segments"]] - word_speaker_mapping = [word for segment in new_segments for word in segment] - words_list = [item.word for item in word_speaker_mapping] - speaker_list = [item.speaker for item in word_speaker_mapping] - - # Fill missing speaker labels by finding the nearest speaker - for i, item in enumerate(word_speaker_mapping): - if item.speaker is None: - item.speaker = find_nearest_speaker(i, word_speaker_mapping, reverse=i > 0) - speaker_list[i] = item.speaker - - # Align speakers with sentence boundaries - k = 0 - while k < len(word_speaker_mapping): - if ( - k < len(word_speaker_mapping) - 1 - and speaker_list[k] != speaker_list[k + 1] - and words_list[k][-1] not in sentence_ending_punctuations - ): - left_idx = get_first_word_idx_of_sentence( - k, words_list, speaker_list, max_words_in_sentence - ) - right_idx = get_last_word_idx_of_sentence( - k, words_list, max_words_in_sentence - (k - left_idx) + # 3. Create new transcription segments with speaker information + segments = self._create_speaker_segments(word_speaker_mapping) + + # 4. Assign confidence variables to the new segments + segments = self._add_segment_variables(segments, self.transcription_segments) + + # Optional: Merge consecutive speaker segments + if self.merge: + segments = self._combine_homeogeneous_speaker_asr_segments(segments) + + return segments + + def _assign_word_speakers( + self, + diarized_segments: list[SpeakerDiarizationSegment], + transcription_segments: list[AsrSegment], + fill_nearest: bool = False, + ) -> list[AsrSegment]: + """Assigns speaker labels to each segment and word in the transcription based on diarized output. + + Args: + diarized_segments (list[SpeakerDiarizationSegment]): Contains speaker diarization segments. + transcription_segments (list[AsrSegment]): Transcription data with segments, text, and language_info. + fill_nearest (bool): If True, assigns the closest speaker even if there's no positive overlap. Default is False. + + Returns: + transcription_segments (list[AsrSegment]): Transcription updated in-place with the assigned speaker labels. + """ + for segment in transcription_segments: + # Assign speaker to segment + segment.speaker = self._get_speaker_for_interval( + diarized_segments, + segment.time_interval.start, + segment.time_interval.end, + fill_nearest, ) - if min(left_idx, right_idx) == -1: - k += 1 - continue - - spk_labels = speaker_list[left_idx : right_idx + 1] - mod_speaker = max(set(spk_labels), key=spk_labels.count) - - if spk_labels.count(mod_speaker) >= len(spk_labels) // 2: - speaker_list[left_idx : right_idx + 1] = [mod_speaker] * ( - right_idx - left_idx + 1 + # Assign speakers to words within the segment + if segment.words: + for word in segment.words: + word.speaker = self._get_speaker_for_interval( + diarized_segments, + word.time_interval.start, + word.time_interval.end, + fill_nearest, + ) + + return transcription_segments + + def _get_speaker_for_interval( + self, + sd_segments: list[SpeakerDiarizationSegment], + start_time: float, + end_time: float, + fill_nearest: bool, + ) -> str | None: + """Determines the speaker for a given time interval based on diarized segments. + + Args: + sd_segments (list[SpeakerDiarizationSegment]): List of speaker diarization segments. + start_time (float): Start time of the interval. + end_time (float): End time of the interval. + fill_nearest (bool): If True, selects the closest speaker even with no overlap. 
+ + Returns: + str | None: The identified speaker label, or None if no speaker is found. + """ + overlaps = [] + + for sd_segment in sd_segments: + interval_start = sd_segment.time_interval.start + interval_end = sd_segment.time_interval.end + + # Calculate overlap duration + overlap_start = max(start_time, interval_start) + overlap_end = min(end_time, interval_end) + overlap_duration = max(0.0, overlap_end - overlap_start) + + if overlap_duration > 0 or fill_nearest: + distance = float( + min(abs(start_time - interval_end), abs(end_time - interval_start)) ) - k = right_idx - - k += 1 - - # Realign the speaker labels in the original word_speaker_mapping - for i, item in enumerate(word_speaker_mapping): - item.speaker = speaker_list[i] - - return word_speaker_mapping - - -def create_new_segment( - word_info: AsrWord, speaker: str | None, is_empty: bool = False -) -> AsrSegment: - """Creates a new segment based on word information. - - Args: - word_info (AsrWord): The word information containing text, timing, etc. - speaker (str | None): The speaker associated with this word. - is_empty (bool): If True, creates an empty segment (for punctuation-only segments). - - Returns: - AsrSegment: A new segment with the provided word information and speaker details. - """ - return AsrSegment( - time_interval=TimeInterval( - start=word_info.time_interval.start - if not is_empty - else word_info.time_interval.end, - end=word_info.time_interval.end, - ), - text=word_info.word if not is_empty else "", - speaker=speaker, - words=[word_info] if not is_empty else [], - confidence=None, - no_speech_confidence=None, - ) - - -def create_speaker_segments( - word_list: list[AsrWord], max_words_per_segment: int = 50 -) -> list[AsrSegment]: - """Creates speaker segments from a list of words with speaker annotations. - - Args: - word_list (List[AsrWord]): A list of words with associated speaker and timing details. - max_words_per_segment (int): The maximum number of words per segment. If the segment exceeds this, - it will be split at previous or next sentence-ending punctuation. - - Returns: - List[AsrSegment]: A list of segments where each segment groups words spoken by the same speaker. - """ - if not word_list: - return [] - - current_speaker = None - current_segment = None - final_segments: list[AsrSegment] = [] - word_count = 0 - - for word_info in word_list: - speaker = word_info.speaker or current_speaker - # Handle speaker change - if speaker != current_speaker: - if current_segment: - final_segments.append(current_segment) - current_segment = create_new_segment(word_info, speaker) - current_speaker = speaker - word_count = 1 - else: - if current_segment: - # Handle word count and punctuation splitting - current_segment, word_count = split_segment_on_length_punctuation( - current_segment, - word_info, - word_count, - max_words_per_segment, - final_segments, + overlaps.append( + { + "speaker": sd_segment.speaker, + "overlap_duration": overlap_duration, + "distance": distance, + } ) - # Add the final segment if it exists - if current_segment and current_segment.words: - final_segments.append(current_segment) - - return final_segments - - -def add_segment_variables( - segments: list[AsrSegment], transcription: WhisperOutput -) -> list[AsrSegment]: - """Adds confidence and no_speech_confidence variables to each segment. - - Args: - segments (List[AsrSegment]): A list of segments to which the confidence values will be added. 
- transcription (WhisperOutput): The transcription data to help determine segment confidence. + if not overlaps: + return None + else: + # Select the speaker with the maximum overlap duration (or minimal distance) + best_match = max( + overlaps, + key=lambda x: (x["overlap_duration"], -x["distance"]) + if not fill_nearest + else (-x["distance"]), + ) + return best_match["speaker"] + + def _get_first_word_idx_of_sentence( + self, + word_idx: int, + word_list: list[str], + speaker_list: list[str | None], + max_words: int, + ) -> int: + """Get the index of the first word of the sentence in the given range.""" + left_idx = word_idx + while ( + left_idx > 0 + and word_idx - left_idx < max_words + and speaker_list[left_idx - 1] == speaker_list[left_idx] + and word_list[left_idx - 1][-1] not in sentence_ending_punctuations + ): + left_idx -= 1 - Returns: - List[AsrSegment]: Segments with confidence and no_speech_confidence added. - """ - for segment in segments: - confidence, no_speech_confidence = determine_major_segment_confidence( - segment, transcription + return ( + left_idx + if left_idx == 0 + or word_list[left_idx - 1][-1] in sentence_ending_punctuations + else -1 ) - segment.confidence = confidence - segment.no_speech_confidence = no_speech_confidence - return segments - - -def split_segment_on_length_punctuation( - current_segment: AsrSegment, - word_info: AsrWord, - word_count: int, - max_words_per_segment: int, - final_segments: list[AsrSegment], -) -> tuple[AsrSegment, int]: - """Splits segments based on length and sentence-ending punctuation. - Args: - current_segment (AsrSegment): The current speaker segment being processed. - word_info (AsrWord): Word information containing timing and text. - word_count (int): The current word count in the segment. - max_words_per_segment (int): Maximum number of words allowed in a segment before splitting. - final_segments (List[AsrSegment]): List of segments to which the completed segment will be added. + def _get_last_word_idx_of_sentence( + self, word_idx: int, word_list: list[str], max_words: int + ) -> int: + """Get the index of the last word of the sentence in the given range.""" + right_idx = word_idx + while ( + right_idx < len(word_list) + and right_idx - word_idx < max_words + and word_list[right_idx][-1] not in sentence_ending_punctuations + ): + right_idx += 1 - Returns: - Tuple[AsrSegment, int]: The updated segment and word count. 
- """ - # Check if word count exceeds the limit and if punctuation exists to split - if word_count >= max_words_per_segment and any( - p in word_info.word for p in sentence_ending_punctuations - ): - # update current segment and then append it - current_segment.time_interval.end = word_info.time_interval.end - current_segment.text += f"{word_info.word}" - current_segment.words.append(word_info) - final_segments.append(current_segment) - current_segment = create_new_segment( - word_info, current_segment.speaker, is_empty=True + return ( + right_idx + if right_idx == len(word_list) - 1 + or word_list[right_idx][-1] in sentence_ending_punctuations + else -1 ) - word_count = 0 # Reset word count - - else: - # Append word to the current segment - current_segment.time_interval.end = word_info.time_interval.end - current_segment.text += f"{word_info.word}" - current_segment.words.append(word_info) - word_count += 1 - # If sentence-ending punctuation is found, finalize the segment - # if any(p in word_info.word for p in sentence_ending_punctuations): - # final_segments.append(current_segment) - # current_segment = create_new_segment( - # word_info, current_segment.speaker, is_empty=True - # ) - # word_count = 0 # Reset word count after punctuation - - return current_segment, word_count - - -def determine_major_segment_confidence( - segment: AsrSegment, transcription: WhisperOutput -) -> tuple[float | None, float | None]: - """Determines the confidence and no_speech_confidence based on the major segment (which contributes the most time or words). - - Args: - segment (AsrSegment): New ASR segment. - transcription (WhisperOutput): Original transcription containing segments with confidence. + def _find_nearest_speaker( + self, index: int, word_speaker_mapping: list[AsrWord], reverse: bool = False + ) -> str | None: + """Find the nearest speaker label in the word_speaker_mapping either forward or backward. + + Args: + index (int): The index to start searching from. + word_speaker_mapping (list[AsrWord]): List of word-speaker mappings. + reverse (bool): Search backwards if True; forwards if False. Default is False. + + Returns: + str | None: The nearest speaker found or None if not found. + """ + step = -1 if reverse else 1 + for i in range(index, len(word_speaker_mapping) if not reverse else -1, step): + if word_speaker_mapping[i].speaker: + return word_speaker_mapping[i].speaker + return None - Returns: - tuple[Optional[float], Optional[float]]: Confidence and no_speech_confidence from the major segment. - """ + def _align_with_punctuation( + self, transcription_segments: list[AsrSegment], max_words_in_sentence: int = 50 + ) -> list[AsrWord]: + """Aligns speaker labels with sentence boundaries defined by punctuation. + + Args: + transcription_segments (list[AsrSegment]): transcription segments with speaker information. + max_words_in_sentence (int): Maximum number of words allowed in a sentence. + + Returns: + word_speaker_mapping: (list[AsrWord]): Realigned word-speaker mappings. 
+ """ + new_segments = [segment.words for segment in transcription_segments] + word_speaker_mapping = [word for segment in new_segments for word in segment] + words_list = [item.word for item in word_speaker_mapping] + speaker_list = [item.speaker for item in word_speaker_mapping] + + # Fill missing speaker labels by finding the nearest speaker + for i, item in enumerate(word_speaker_mapping): + if item.speaker is None: + item.speaker = self._find_nearest_speaker( + i, word_speaker_mapping, reverse=i > 0 + ) + speaker_list[i] = item.speaker + + # Align speakers with sentence boundaries + k = 0 + while k < len(word_speaker_mapping): + if ( + k < len(word_speaker_mapping) - 1 + and speaker_list[k] != speaker_list[k + 1] + and words_list[k][-1] not in sentence_ending_punctuations + ): + left_idx = self._get_first_word_idx_of_sentence( + k, words_list, speaker_list, max_words_in_sentence + ) + right_idx = self._get_last_word_idx_of_sentence( + k, words_list, max_words_in_sentence - (k - left_idx) + ) - def find_closest_segment(word_start: float, word_end: float) -> AsrSegment | None: - """Finds the closest segment in the transcription for the given word start and end times.""" - closest_segment = min( - transcription["segments"], - key=lambda segment: abs(segment.time_interval.start - word_start) - + abs(segment.time_interval.end - word_end), - default=None, + if min(left_idx, right_idx) == -1: + k += 1 + continue + + spk_labels = speaker_list[left_idx : right_idx + 1] + mod_speaker = max(set(spk_labels), key=spk_labels.count) + + if spk_labels.count(mod_speaker) >= len(spk_labels) // 2: + speaker_list[left_idx : right_idx + 1] = [mod_speaker] * ( + right_idx - left_idx + 1 + ) + k = right_idx + + k += 1 + + # Realign the speaker labels in the original word_speaker_mapping + for i, item in enumerate(word_speaker_mapping): + item.speaker = speaker_list[i] + + return word_speaker_mapping + + def _create_new_segment( + self, word_info: AsrWord, speaker: str | None, is_empty: bool = False + ) -> AsrSegment: + """Creates a new segment based on word information. + + Args: + word_info (AsrWord): The word information containing text, timing, etc. + speaker (str | None): The speaker associated with this word. + is_empty (bool): If True, creates an empty segment (for punctuation-only segments). + + Returns: + AsrSegment: A new segment with the provided word information and speaker details. 
+ """ + return AsrSegment( + time_interval=TimeInterval( + start=word_info.time_interval.start + if not is_empty + else word_info.time_interval.end, + end=word_info.time_interval.end, + ), + text=word_info.word if not is_empty else "", + speaker=speaker, + words=[word_info] if not is_empty else [], + confidence=None, + no_speech_confidence=None, ) - return closest_segment - - def update_segment_contribution( - contributions: dict, segment: AsrSegment, word_duration: float - ) -> None: - """Updates the contribution data for the given segment.""" - segment_id = id(segment) - if segment_id not in contributions: - contributions[segment_id] = { - "segment": segment, - "contribution_time": 0.0, - "word_count": 0, - } - contributions[segment_id]["contribution_time"] += word_duration - contributions[segment_id]["word_count"] += 1 - - segment_contributions: defaultdict = defaultdict( - lambda: {"segment": None, "contribution_time": 0.0, "word_count": 0} - ) - - for word in segment.words: - word_start, word_end = word.time_interval.start, word.time_interval.end - word_duration = word_end - word_start - - closest_segment = find_closest_segment(word_start, word_end) - - if closest_segment: - update_segment_contribution( - segment_contributions, closest_segment, word_duration - ) - - if not segment_contributions: - return None, None - - # Determine the segment with the highest word count or contribution time - major_segment_data = max( - segment_contributions.values(), - key=lambda data: data[ - "word_count" - ], # Change this to 'contribution_time' if needed - ) - - major_segment = major_segment_data["segment"] - return major_segment.confidence, major_segment.no_speech_confidence - - -def merge_consecutive_speaker_segments( - segments: list[AsrSegment], -) -> list[AsrSegment]: - """Merges consecutive segments that have the same speaker into a single segment. - - Args: - segments (List[AsrSegment]): The initial list of segments. - - Returns: - merged_segments (List[AsrSegment]): A new list of merged segments. - """ - if not segments: - return [] - - merged_segments: list[AsrSegment] = [] - mapping: defaultdict[str, str] = defaultdict(str) - current_segment = segments[0] - speaker_counter = 0 + def _create_speaker_segments( + self, word_list: list[AsrWord], max_words_per_segment: int = 50 + ) -> list[AsrSegment]: + """Creates speaker segments from a list of words with speaker annotations. + + Args: + word_list (list[AsrWord]): A list of words with associated speaker and timing details. + max_words_per_segment (int): The maximum number of words per segment. If the segment exceeds this, + it will be split at previous or next sentence-ending punctuation. + + Returns: + list[AsrSegment]: A list of segments where each segment groups words spoken by the same speaker. 
+ """ + if not word_list: + return [] + + current_speaker = None + current_segment = None + final_segments: list[AsrSegment] = [] + word_count = 0 + + for word_info in word_list: + speaker = word_info.speaker or current_speaker + + # Handle speaker change + if speaker != current_speaker: + if current_segment: + final_segments.append(current_segment) + current_segment = self._create_new_segment(word_info, speaker) + current_speaker = speaker + word_count = 1 + else: + if current_segment: + # Handle word count and punctuation splitting + ( + current_segment, + word_count, + ) = self._split_segment_on_length_punctuation( + current_segment, + word_info, + word_count, + max_words_per_segment, + final_segments, + ) + + # Add the final segment if it exists + if current_segment and current_segment.words: + final_segments.append(current_segment) + + return final_segments + + def _add_segment_variables( + self, segments: list[AsrSegment], transcription_segments: list[AsrSegment] + ) -> list[AsrSegment]: + """Adds confidence and no_speech_confidence variables to each segment. + + Args: + segments (list[AsrSegment]): A list of segments to which the confidence values will be added. + transcription_segments (list[AsrSegment]): The original transcription segments to help determine segment confidence. + + Returns: + list[AsrSegment]: Segments with confidence and no_speech_confidence added. + """ + for segment in segments: + confidence, no_speech_confidence = self._determine_major_segment_confidence( + segment, transcription_segments + ) + segment.confidence = confidence + segment.no_speech_confidence = no_speech_confidence + return segments + + def _split_segment_on_length_punctuation( + self, + current_segment: AsrSegment, + word_info: AsrWord, + word_count: int, + max_words_per_segment: int, + final_segments: list[AsrSegment], + ) -> tuple[AsrSegment, int]: + """Splits segments based on length and sentence-ending punctuation. + + Args: + current_segment (AsrSegment): The current speaker segment being processed. + word_info (AsrWord): Word information containing timing and text. + word_count (int): The current word count in the segment. + max_words_per_segment (int): Maximum number of words allowed in a segment before splitting. + final_segments (list[AsrSegment]): List of segments to which the completed segment will be added. + + Returns: + Tuple[AsrSegment, int]: The updated segment and word count. 
+ """ + # Check if word count exceeds the limit and if punctuation exists to split + if word_count >= max_words_per_segment and any( + p in word_info.word for p in sentence_ending_punctuations + ): + # update current segment and then append it + current_segment.time_interval.end = word_info.time_interval.end + current_segment.text += f"{word_info.word}" + current_segment.words.append(word_info) + final_segments.append(current_segment) + current_segment = self._create_new_segment( + word_info, current_segment.speaker, is_empty=True + ) + word_count = 0 # Reset word count - for next_segment in segments[1:]: - if next_segment.speaker == current_segment.speaker: - # Merge segments - current_segment.time_interval.end = next_segment.time_interval.end - current_segment.text += f" {next_segment.text}" - current_segment.words.extend(next_segment.words) else: - # Assign unique speaker labels and finalize the current segment - if current_segment.speaker: - current_segment.speaker = mapping.setdefault( - current_segment.speaker, f"SPEAKER_{speaker_counter:02d}" - ) - for word in current_segment.words: - word.speaker = current_segment.speaker - merged_segments.append(current_segment) - current_segment = next_segment - speaker_counter += 1 - - # Handle the last segment - if current_segment.speaker: - current_segment.speaker = mapping.setdefault( - current_segment.speaker, f"SPEAKER_{speaker_counter:02d}" - ) - for word in current_segment.words: - word.speaker = current_segment.speaker - merged_segments.append(current_segment) - - return merged_segments - - -# Full Method -def asr_postprocessing_for_diarization( - diarized_output: SpeakerDiarizationOutput, - transcription: WhisperOutput, - merge: bool = False, -) -> WhisperOutput: - """Perform diarized transcription by combining outputs from individual deployments. - - Args: - diarized_output (SpeakerDiarizationOutput): Contains speaker diarization segments. - transcription (WhisperOutput): Transcription data with segments, text, and language_info. - merge (bool): Whether to merge the same speaker segments in the end. - - Returns: - transcription (WhisperOutput): Updated transcription with speaker information per segment/word. - - """ - # 1. Assign speaker labels to each segment and each word in WhisperOutput based on SpeakerDiarizationOutput. - - speaker_labelled_transcription = assign_word_speakers( - diarized_output, transcription - ) - # 2. Aligns the speakers with the punctuations: + # Append word to the current segment + current_segment.time_interval.end = word_info.time_interval.end + current_segment.text += f"{word_info.word}" + current_segment.words.append(word_info) + word_count += 1 + + return current_segment, word_count + + def _determine_major_segment_confidence( + self, segment: AsrSegment, transcription_segments: list[AsrSegment] + ) -> tuple[float | None, float | None]: + """Determines the confidence and no_speech_confidence based on the major segment (which contributes the most time or words). + + Args: + segment (AsrSegment): New ASR segment. + transcription_segments (list[AsrSegment]): Original transcription segments with confidence. + + Returns: + tuple[Optional[float], Optional[float]]: Confidence and no_speech_confidence from the major segment. 
+ """ + + def find_closest_segment( + word_start: float, word_end: float + ) -> AsrSegment | None: + """Finds the closest segment in the transcription for the given word start and end times.""" + closest_segment = min( + transcription_segments, + key=lambda segment: abs(segment.time_interval.start - word_start) + + abs(segment.time_interval.end - word_end), + default=None, + ) + return closest_segment + + def update_segment_contribution( + contributions: dict, segment: AsrSegment, word_duration: float + ) -> None: + """Updates the contribution data for the given segment.""" + segment_id = id(segment) + if segment_id not in contributions: + contributions[segment_id] = { + "segment": segment, + "contribution_time": 0.0, + "word_count": 0, + } + contributions[segment_id]["contribution_time"] += word_duration + contributions[segment_id]["word_count"] += 1 - word_speaker_mapping = align_with_punctuation(speaker_labelled_transcription) + segment_contributions: defaultdict = defaultdict( + lambda: {"segment": None, "contribution_time": 0.0, "word_count": 0} + ) - # 3. Create ASR segments by combining the AsrWord with speaker information + for word in segment.words: + word_start, word_end = word.time_interval.start, word.time_interval.end + word_duration = word_end - word_start - # a. Create speaker segments from new word_speaker_mapping - # b. Limits its length (default 50 words) + closest_segment = find_closest_segment(word_start, word_end) - # a & b - segments = create_speaker_segments(word_speaker_mapping) + if closest_segment: + update_segment_contribution( + segment_contributions, closest_segment, word_duration + ) - # c. Assign new confidence and no_speech_confidence to new segments + if not segment_contributions: + return None, None - segments = add_segment_variables(segments, transcription) + # Determine the segment with the highest word count or contribution time + major_segment_data = max( + segment_contributions.values(), + key=lambda data: data[ + "word_count" + ], # Change this to 'contribution_time' if needed + ) - # Optional: Merge consecutive speaker segments - if merge: - segments = merge_consecutive_speaker_segments(segments) + major_segment = major_segment_data["segment"] + return major_segment.confidence, major_segment.no_speech_confidence + + def _combine_homeogeneous_speaker_asr_segments( + self, + segments: list[AsrSegment], + ) -> list[AsrSegment]: + """Merges consecutive segments that have the same speaker into a single segment. + + Args: + segments (list[AsrSegment]): The initial list of segments. + + Returns: + merged_segments (list[AsrSegment]): A new list of merged segments. 
+ """ + if not segments: + return [] + + merged_segments: list[AsrSegment] = [] + mapping: defaultdict[str, str] = defaultdict(str) + + current_segment = segments[0] + speaker_counter = 0 + + for next_segment in segments[1:]: + if next_segment.speaker == current_segment.speaker: + # Merge segments + current_segment.time_interval.end = next_segment.time_interval.end + current_segment.text += f" {next_segment.text}" + current_segment.words.extend(next_segment.words) + else: + # Assign unique speaker labels and finalize the current segment + if current_segment.speaker: + current_segment.speaker = mapping.setdefault( + current_segment.speaker, f"SPEAKER_{speaker_counter:02d}" + ) + for word in current_segment.words: + word.speaker = current_segment.speaker + merged_segments.append(current_segment) + current_segment = next_segment + speaker_counter += 1 + + # Handle the last segment + if current_segment.speaker: + current_segment.speaker = mapping.setdefault( + current_segment.speaker, f"SPEAKER_{speaker_counter:02d}" + ) + for word in current_segment.words: + word.speaker = current_segment.speaker + merged_segments.append(current_segment) - transcription["segments"] = segments - return transcription + return merged_segments # speaker diarization model occationally produce overlapping chunks/ same speaker segments, # below function combines them properly -def combine_homogeneous_speaker_segs( - diarized_output: SpeakerDiarizationOutput, -) -> SpeakerDiarizationOutput: +def combine_homogeneous_speaker_diarization_segments( + diarized_segments: list[SpeakerDiarizationSegment], +) -> list[SpeakerDiarizationSegment]: """Combines segments with the same speaker into homogeneous speaker segments, ensuring no overlapping times. Args: - diarized_output (SpeakerDiarizationOutput): Input with segments that may have overlapping times. + diarized_segments (list(SpeakerDiarizationSegment)): Input with segments that may have overlapping times. Returns: - SpeakerDiarizationOutput: Output with combined homogeneous speaker segments. + list(SpeakerDiarizationSegment): Output with combined homogeneous speaker segments. """ combined_segments: list = [] current_speaker = None current_segment = None - for segment in sorted( - diarized_output["segments"], key=lambda x: x.time_interval.start - ): + for segment in sorted(diarized_segments, key=lambda x: x.time_interval.start): speaker = segment.speaker # If there's a speaker change or current_segment is None, finalize current and start a new one @@ -630,4 +618,4 @@ def combine_homogeneous_speaker_segs( if current_segment: combined_segments.append(current_segment) - return SpeakerDiarizationOutput(segments=combined_segments) + return combined_segments From fecae593390fc6306e6e2fd852f8ae5fb8b0caca Mon Sep 17 00:00:00 2001 From: jiltseb Date: Tue, 24 Sep 2024 15:57:04 +0000 Subject: [PATCH 15/18] wrapping postprocessing to classmethod, adding test, doc changes --- aana/processors/speaker.py | 523 +++++++++--------- aana/tests/units/test_speaker.py | 24 +- docs/pages/model_hub/asr.md | 87 +-- docs/pages/model_hub/speaker_recognition.md | 6 +- docs/reference/processors.md | 2 +- .../diarized_transcription_example.ipynb | 15 +- 6 files changed, 320 insertions(+), 337 deletions(-) diff --git a/aana/processors/speaker.py b/aana/processors/speaker.py index 140e0dff..e2f7276e 100644 --- a/aana/processors/speaker.py +++ b/aana/processors/speaker.py @@ -11,7 +11,7 @@ sentence_ending_punctuations = ".?!" 
-class ASRPostProcessingForDiarization: +class PostProcessingForDiarizedAsr: """Class to handle post-processing for diarized ASR output by combining diarization and transcription segments. The post-processing involves assigning speaker labels to transcription segments and words, aligning speakers @@ -23,29 +23,13 @@ class ASRPostProcessingForDiarization: merge (bool): Whether to merge the same speaker segments in the final output. """ - def __init__( - self, + @classmethod + def process( + cls, diarized_segments: list[SpeakerDiarizationSegment], transcription_segments: list[AsrSegment], merge: bool = False, - ): - """Initializes the ASRPostProcessingForDiarization class with diarization and transcription segments. - - Args: - diarized_segments (list[SpeakerDiarizationSegment]): Contains speaker diarization segments. - transcription_segments (list[AsrSegment]): Transcription segments. - merge (bool): If True, merges consecutive speaker segments in the final output. Defaults to False. - """ - self.diarized_segments = diarized_segments - self.transcription_segments = transcription_segments - self.merge = merge - - # Check if inputs are valid: - for segment in self.transcription_segments: - if segment.text and not segment.words: - raise ValueError("Word-level timestamps are required for diarized ASR.") # noqa: TRY003 - - def process(self) -> list[AsrSegment]: + ) -> list[AsrSegment]: """Executes the post-processing pipeline that combines diarization and transcription segments. This method performs the following steps: @@ -55,190 +39,100 @@ def process(self) -> list[AsrSegment]: 4. Optionally, merge consecutive speaker segments. 5. Add confidence and no_speech_confidence to the new segments. + Args: + diarized_segments (list[SpeakerDiarizationSegment]): Contains speaker diarization segments. + transcription_segments (list[AsrSegment]): Transcription segments. + merge (bool): If True, merges consecutive speaker segments in the final output. Defaults to False. + Returns: list[AsrSegment]: Updated transcription segments with speaker information per segment and word. """ # intro: validation checks - if not self.transcription_segments or not self.diarized_segments: - return self.transcription_segments + if not transcription_segments or not diarized_segments: + return transcription_segments + + # Check if inputs are valid: + for segment in transcription_segments: + if segment.text and not segment.words: + raise ValueError("Word-level timestamps are required for diarized ASR.") # noqa: TRY003 # 1. Assign speaker labels to each segment and word - speaker_labelled_transcription = self._assign_word_speakers( - self.diarized_segments, self.transcription_segments + speaker_labelled_transcription = cls._assign_word_speakers( + diarized_segments, transcription_segments ) # 2. Align speakers with punctuation - word_speaker_mapping = self._align_with_punctuation( + word_speaker_mapping = cls._align_with_punctuation( speaker_labelled_transcription ) # 3. Create new transcription segments with speaker information - segments = self._create_speaker_segments(word_speaker_mapping) + segments = cls._create_speaker_segments(word_speaker_mapping) # 4. 
Assign confidence variables to the new segments - segments = self._add_segment_variables(segments, self.transcription_segments) + segments = cls._add_segment_variables(segments, transcription_segments) # Optional: Merge consecutive speaker segments - if self.merge: - segments = self._combine_homeogeneous_speaker_asr_segments(segments) + if merge: + segments = cls._combine_homeogeneous_speaker_asr_segments(segments) return segments - def _assign_word_speakers( - self, - diarized_segments: list[SpeakerDiarizationSegment], - transcription_segments: list[AsrSegment], - fill_nearest: bool = False, + @classmethod + def _create_speaker_segments( + cls, word_list: list[AsrWord], max_words_per_segment: int = 50 ) -> list[AsrSegment]: - """Assigns speaker labels to each segment and word in the transcription based on diarized output. - - Args: - diarized_segments (list[SpeakerDiarizationSegment]): Contains speaker diarization segments. - transcription_segments (list[AsrSegment]): Transcription data with segments, text, and language_info. - fill_nearest (bool): If True, assigns the closest speaker even if there's no positive overlap. Default is False. - - Returns: - transcription_segments (list[AsrSegment]): Transcription updated in-place with the assigned speaker labels. - """ - for segment in transcription_segments: - # Assign speaker to segment - segment.speaker = self._get_speaker_for_interval( - diarized_segments, - segment.time_interval.start, - segment.time_interval.end, - fill_nearest, - ) - - # Assign speakers to words within the segment - if segment.words: - for word in segment.words: - word.speaker = self._get_speaker_for_interval( - diarized_segments, - word.time_interval.start, - word.time_interval.end, - fill_nearest, - ) - - return transcription_segments - - def _get_speaker_for_interval( - self, - sd_segments: list[SpeakerDiarizationSegment], - start_time: float, - end_time: float, - fill_nearest: bool, - ) -> str | None: - """Determines the speaker for a given time interval based on diarized segments. + """Creates speaker segments from a list of words with speaker annotations. Args: - sd_segments (list[SpeakerDiarizationSegment]): List of speaker diarization segments. - start_time (float): Start time of the interval. - end_time (float): End time of the interval. - fill_nearest (bool): If True, selects the closest speaker even with no overlap. + word_list (list[AsrWord]): A list of words with associated speaker and timing details. + max_words_per_segment (int): The maximum number of words per segment. If the segment exceeds this, + it will be split at previous or next sentence-ending punctuation. Returns: - str | None: The identified speaker label, or None if no speaker is found. + list[AsrSegment]: A list of segments where each segment groups words spoken by the same speaker. 
""" - overlaps = [] - - for sd_segment in sd_segments: - interval_start = sd_segment.time_interval.start - interval_end = sd_segment.time_interval.end - - # Calculate overlap duration - overlap_start = max(start_time, interval_start) - overlap_end = min(end_time, interval_end) - overlap_duration = max(0.0, overlap_end - overlap_start) - - if overlap_duration > 0 or fill_nearest: - distance = float( - min(abs(start_time - interval_end), abs(end_time - interval_start)) - ) - - overlaps.append( - { - "speaker": sd_segment.speaker, - "overlap_duration": overlap_duration, - "distance": distance, - } - ) - - if not overlaps: - return None - else: - # Select the speaker with the maximum overlap duration (or minimal distance) - best_match = max( - overlaps, - key=lambda x: (x["overlap_duration"], -x["distance"]) - if not fill_nearest - else (-x["distance"]), - ) - return best_match["speaker"] - - def _get_first_word_idx_of_sentence( - self, - word_idx: int, - word_list: list[str], - speaker_list: list[str | None], - max_words: int, - ) -> int: - """Get the index of the first word of the sentence in the given range.""" - left_idx = word_idx - while ( - left_idx > 0 - and word_idx - left_idx < max_words - and speaker_list[left_idx - 1] == speaker_list[left_idx] - and word_list[left_idx - 1][-1] not in sentence_ending_punctuations - ): - left_idx -= 1 - - return ( - left_idx - if left_idx == 0 - or word_list[left_idx - 1][-1] in sentence_ending_punctuations - else -1 - ) + if not word_list: + return [] - def _get_last_word_idx_of_sentence( - self, word_idx: int, word_list: list[str], max_words: int - ) -> int: - """Get the index of the last word of the sentence in the given range.""" - right_idx = word_idx - while ( - right_idx < len(word_list) - and right_idx - word_idx < max_words - and word_list[right_idx][-1] not in sentence_ending_punctuations - ): - right_idx += 1 + current_speaker = None + current_segment = None + final_segments: list[AsrSegment] = [] + word_count = 0 - return ( - right_idx - if right_idx == len(word_list) - 1 - or word_list[right_idx][-1] in sentence_ending_punctuations - else -1 - ) + for word_info in word_list: + speaker = word_info.speaker or current_speaker - def _find_nearest_speaker( - self, index: int, word_speaker_mapping: list[AsrWord], reverse: bool = False - ) -> str | None: - """Find the nearest speaker label in the word_speaker_mapping either forward or backward. + # Handle speaker change + if speaker != current_speaker: + if current_segment: + final_segments.append(current_segment) + current_segment = cls._create_new_segment(word_info, speaker) + current_speaker = speaker + word_count = 1 + else: + if current_segment: + # Handle word count and punctuation splitting + ( + current_segment, + word_count, + ) = cls._split_segment_on_length_punctuation( + current_segment, + word_info, + word_count, + max_words_per_segment, + final_segments, + ) - Args: - index (int): The index to start searching from. - word_speaker_mapping (list[AsrWord]): List of word-speaker mappings. - reverse (bool): Search backwards if True; forwards if False. Default is False. + # Add the final segment if it exists + if current_segment and current_segment.words: + final_segments.append(current_segment) - Returns: - str | None: The nearest speaker found or None if not found. 
- """ - step = -1 if reverse else 1 - for i in range(index, len(word_speaker_mapping) if not reverse else -1, step): - if word_speaker_mapping[i].speaker: - return word_speaker_mapping[i].speaker - return None + return final_segments + @classmethod def _align_with_punctuation( - self, transcription_segments: list[AsrSegment], max_words_in_sentence: int = 50 + cls, transcription_segments: list[AsrSegment], max_words_in_sentence: int = 50 ) -> list[AsrWord]: """Aligns speaker labels with sentence boundaries defined by punctuation. @@ -257,7 +151,7 @@ def _align_with_punctuation( # Fill missing speaker labels by finding the nearest speaker for i, item in enumerate(word_speaker_mapping): if item.speaker is None: - item.speaker = self._find_nearest_speaker( + item.speaker = cls._find_nearest_speaker( i, word_speaker_mapping, reverse=i > 0 ) speaker_list[i] = item.speaker @@ -270,10 +164,10 @@ def _align_with_punctuation( and speaker_list[k] != speaker_list[k + 1] and words_list[k][-1] not in sentence_ending_punctuations ): - left_idx = self._get_first_word_idx_of_sentence( + left_idx = cls._get_first_word_idx_of_sentence( k, words_list, speaker_list, max_words_in_sentence ) - right_idx = self._get_last_word_idx_of_sentence( + right_idx = cls._get_last_word_idx_of_sentence( k, words_list, max_words_in_sentence - (k - left_idx) ) @@ -298,86 +192,9 @@ def _align_with_punctuation( return word_speaker_mapping - def _create_new_segment( - self, word_info: AsrWord, speaker: str | None, is_empty: bool = False - ) -> AsrSegment: - """Creates a new segment based on word information. - - Args: - word_info (AsrWord): The word information containing text, timing, etc. - speaker (str | None): The speaker associated with this word. - is_empty (bool): If True, creates an empty segment (for punctuation-only segments). - - Returns: - AsrSegment: A new segment with the provided word information and speaker details. - """ - return AsrSegment( - time_interval=TimeInterval( - start=word_info.time_interval.start - if not is_empty - else word_info.time_interval.end, - end=word_info.time_interval.end, - ), - text=word_info.word if not is_empty else "", - speaker=speaker, - words=[word_info] if not is_empty else [], - confidence=None, - no_speech_confidence=None, - ) - - def _create_speaker_segments( - self, word_list: list[AsrWord], max_words_per_segment: int = 50 - ) -> list[AsrSegment]: - """Creates speaker segments from a list of words with speaker annotations. - - Args: - word_list (list[AsrWord]): A list of words with associated speaker and timing details. - max_words_per_segment (int): The maximum number of words per segment. If the segment exceeds this, - it will be split at previous or next sentence-ending punctuation. - - Returns: - list[AsrSegment]: A list of segments where each segment groups words spoken by the same speaker. 
- """ - if not word_list: - return [] - - current_speaker = None - current_segment = None - final_segments: list[AsrSegment] = [] - word_count = 0 - - for word_info in word_list: - speaker = word_info.speaker or current_speaker - - # Handle speaker change - if speaker != current_speaker: - if current_segment: - final_segments.append(current_segment) - current_segment = self._create_new_segment(word_info, speaker) - current_speaker = speaker - word_count = 1 - else: - if current_segment: - # Handle word count and punctuation splitting - ( - current_segment, - word_count, - ) = self._split_segment_on_length_punctuation( - current_segment, - word_info, - word_count, - max_words_per_segment, - final_segments, - ) - - # Add the final segment if it exists - if current_segment and current_segment.words: - final_segments.append(current_segment) - - return final_segments - + @classmethod def _add_segment_variables( - self, segments: list[AsrSegment], transcription_segments: list[AsrSegment] + cls, segments: list[AsrSegment], transcription_segments: list[AsrSegment] ) -> list[AsrSegment]: """Adds confidence and no_speech_confidence variables to each segment. @@ -389,15 +206,16 @@ def _add_segment_variables( list[AsrSegment]: Segments with confidence and no_speech_confidence added. """ for segment in segments: - confidence, no_speech_confidence = self._determine_major_segment_confidence( + confidence, no_speech_confidence = cls._determine_major_segment_confidence( segment, transcription_segments ) segment.confidence = confidence segment.no_speech_confidence = no_speech_confidence return segments + @classmethod def _split_segment_on_length_punctuation( - self, + cls, current_segment: AsrSegment, word_info: AsrWord, word_count: int, @@ -425,7 +243,7 @@ def _split_segment_on_length_punctuation( current_segment.text += f"{word_info.word}" current_segment.words.append(word_info) final_segments.append(current_segment) - current_segment = self._create_new_segment( + current_segment = cls._create_new_segment( word_info, current_segment.speaker, is_empty=True ) word_count = 0 # Reset word count @@ -439,8 +257,195 @@ def _split_segment_on_length_punctuation( return current_segment, word_count + @staticmethod + def _assign_word_speakers( + diarized_segments: list[SpeakerDiarizationSegment], + transcription_segments: list[AsrSegment], + fill_nearest: bool = False, + ) -> list[AsrSegment]: + """Assigns speaker labels to each segment and word in the transcription based on diarized output. + + Args: + diarized_segments (list[SpeakerDiarizationSegment]): Contains speaker diarization segments. + transcription_segments (list[AsrSegment]): Transcription data with segments, text, and language_info. + fill_nearest (bool): If True, assigns the closest speaker even if there's no positive overlap. Default is False. + + Returns: + transcription_segments (list[AsrSegment]): Transcription updated in-place with the assigned speaker labels. + """ + + def get_speaker_for_interval( + sd_segments: list[SpeakerDiarizationSegment], + start_time: float, + end_time: float, + fill_nearest: bool, + ) -> str | None: + """Determines the speaker for a given time interval based on diarized segments. + + Args: + sd_segments (list[SpeakerDiarizationSegment]): List of speaker diarization segments. + start_time (float): Start time of the interval. + end_time (float): End time of the interval. + fill_nearest (bool): If True, selects the closest speaker even with no overlap. 
+ + Returns: + str | None: The identified speaker label, or None if no speaker is found. + """ + overlaps = [] + + for sd_segment in sd_segments: + interval_start = sd_segment.time_interval.start + interval_end = sd_segment.time_interval.end + + # Calculate overlap duration + overlap_start = max(start_time, interval_start) + overlap_end = min(end_time, interval_end) + overlap_duration = max(0.0, overlap_end - overlap_start) + + if overlap_duration > 0 or fill_nearest: + distance = float( + min( + abs(start_time - interval_end), + abs(end_time - interval_start), + ) + ) + + overlaps.append( + { + "speaker": sd_segment.speaker, + "overlap_duration": overlap_duration, + "distance": distance, + } + ) + + if not overlaps: + return None + else: + # Select the speaker with the maximum overlap duration (or minimal distance) + best_match = max( + overlaps, + key=lambda x: (x["overlap_duration"], -x["distance"]) + if not fill_nearest + else (-x["distance"]), + ) + return best_match["speaker"] + + for segment in transcription_segments: + # Assign speaker to segment + segment.speaker = get_speaker_for_interval( + diarized_segments, + segment.time_interval.start, + segment.time_interval.end, + fill_nearest, + ) + + # Assign speakers to words within the segment + if segment.words: + for word in segment.words: + word.speaker = get_speaker_for_interval( + diarized_segments, + word.time_interval.start, + word.time_interval.end, + fill_nearest, + ) + + return transcription_segments + + @staticmethod + def _get_first_word_idx_of_sentence( + word_idx: int, + word_list: list[str], + speaker_list: list[str | None], + max_words: int, + ) -> int: + """Get the index of the first word of the sentence in the given range.""" + left_idx = word_idx + while ( + left_idx > 0 + and word_idx - left_idx < max_words + and speaker_list[left_idx - 1] == speaker_list[left_idx] + and word_list[left_idx - 1][-1] not in sentence_ending_punctuations + ): + left_idx -= 1 + + return ( + left_idx + if left_idx == 0 + or word_list[left_idx - 1][-1] in sentence_ending_punctuations + else -1 + ) + + @staticmethod + def _get_last_word_idx_of_sentence( + word_idx: int, word_list: list[str], max_words: int + ) -> int: + """Get the index of the last word of the sentence in the given range.""" + right_idx = word_idx + while ( + right_idx < len(word_list) + and right_idx - word_idx < max_words + and word_list[right_idx][-1] not in sentence_ending_punctuations + ): + right_idx += 1 + + return ( + right_idx + if right_idx == len(word_list) - 1 + or word_list[right_idx][-1] in sentence_ending_punctuations + else -1 + ) + + @staticmethod + def _find_nearest_speaker( + index: int, word_speaker_mapping: list[AsrWord], reverse: bool = False + ) -> str | None: + """Find the nearest speaker label in the word_speaker_mapping either forward or backward. + + Args: + index (int): The index to start searching from. + word_speaker_mapping (list[AsrWord]): List of word-speaker mappings. + reverse (bool): Search backwards if True; forwards if False. Default is False. + + Returns: + str | None: The nearest speaker found or None if not found. + """ + step = -1 if reverse else 1 + for i in range(index, len(word_speaker_mapping) if not reverse else -1, step): + if word_speaker_mapping[i].speaker: + return word_speaker_mapping[i].speaker + return None + + @staticmethod + def _create_new_segment( + word_info: AsrWord, speaker: str | None, is_empty: bool = False + ) -> AsrSegment: + """Creates a new segment based on word information. 
+ + Args: + word_info (AsrWord): The word information containing text, timing, etc. + speaker (str | None): The speaker associated with this word. + is_empty (bool): If True, creates an empty segment (for punctuation-only segments). + + Returns: + AsrSegment: A new segment with the provided word information and speaker details. + """ + return AsrSegment( + time_interval=TimeInterval( + start=word_info.time_interval.start + if not is_empty + else word_info.time_interval.end, + end=word_info.time_interval.end, + ), + text=word_info.word if not is_empty else "", + speaker=speaker, + words=[word_info] if not is_empty else [], + confidence=None, + no_speech_confidence=None, + ) + + @staticmethod def _determine_major_segment_confidence( - self, segment: AsrSegment, transcription_segments: list[AsrSegment] + segment: AsrSegment, transcription_segments: list[AsrSegment] ) -> tuple[float | None, float | None]: """Determines the confidence and no_speech_confidence based on the major segment (which contributes the most time or words). @@ -507,8 +512,8 @@ def update_segment_contribution( major_segment = major_segment_data["segment"] return major_segment.confidence, major_segment.no_speech_confidence + @staticmethod def _combine_homeogeneous_speaker_asr_segments( - self, segments: list[AsrSegment], ) -> list[AsrSegment]: """Merges consecutive segments that have the same speaker into a single segment. diff --git a/aana/tests/units/test_speaker.py b/aana/tests/units/test_speaker.py index bd84f684..ed600874 100644 --- a/aana/tests/units/test_speaker.py +++ b/aana/tests/units/test_speaker.py @@ -8,14 +8,14 @@ from aana.core.models.asr import AsrSegment from aana.core.models.speaker import SpeakerDiarizationSegment -from aana.processors.speaker import ASRPostProcessingForDiarization +from aana.processors.speaker import PostProcessingForDiarizedAsr from aana.tests.utils import verify_deployment_results @pytest.mark.parametrize("audio_file", ["sd_sample.wav"]) def test_asr_diarization_post_process(audio_file: Literal["sd_sample.wav"]): - """Test that the ASR output can be processed to generate diarized transcription.""" - # load precomputed ASR and Diarization outputs + """Test that the ASR output can be processed to generate diarized transcription and an invalid ASR output leads to ValueError.""" + # Load precomputed ASR and Diarization outputs asr_path = ( resources.files("aana.tests.files.expected.whisper") / f"whisper_medium_{audio_file}.json" @@ -44,9 +44,21 @@ def test_asr_diarization_post_process(audio_file: Literal["sd_sample.wav"]): SpeakerDiarizationSegment.model_validate(segment) for segment in diar_op["segments"] ] - post_processor = ASRPostProcessingForDiarization( + asr_op["segments"] = PostProcessingForDiarizedAsr.process( diarized_segments=diarized_segments, transcription_segments=asr_segments ) - asr_op["segments"] = post_processor.process() - verify_deployment_results(expected_results_path, asr_op) + + # Raise error if the ASR output is a invalid input for combining with diarization. + + # setting words to empty list + for segment in asr_segments: + segment.words = [] + + # Expect ValueError with the specific error message + with pytest.raises( + ValueError, match="Word-level timestamps are required for diarized ASR." 
+ ): + PostProcessingForDiarizedAsr.process( + diarized_segments=diarized_segments, transcription_segments=asr_segments + ) diff --git a/docs/pages/model_hub/asr.md b/docs/pages/model_hub/asr.md index a8d1f513..7146bfbf 100644 --- a/docs/pages/model_hub/asr.md +++ b/docs/pages/model_hub/asr.md @@ -53,64 +53,33 @@ Here are some other possible configurations for the Whisper deployment: ### Diarized ASR -Diarized transcription can be generated by using [WhisperDeployment](./../../reference/deployments.md#aana.deployments.WhisperDeployment) and [PyannoteSpeakerDiarizationDeployment](./../../reference/deployments.md#aana.deployments.PyannoteSpeakerDiarizationDeployment) and combining the timelines using post processing with [ASRPostProcessingForDiarization](./../../reference/processors.md). +Diarized transcription can be generated by using [WhisperDeployment](./../../reference/deployments.md#aana.deployments.WhisperDeployment) and [PyannoteSpeakerDiarizationDeployment](./../../reference/deployments.md#aana.deployments.PyannoteSpeakerDiarizationDeployment) and combining the timelines using post processing with [PostProcessingForDiarizedAsr](./../../reference/processors.md#aana.processors.speaker.PostProcessingForDiarizedAsr). + +Example configuration for the PyannoteSpeakerDiarization model is available at [Speaker Diarization](./speaker_recognition.md/#speaker-diarization-sd-models) model hub. + +You can simply define the model deployments and the endpoint to transcribe the video with diarization. Below code snippet shows the how to combine the outputs from ASR and diarization deployments: + + +```python +from aana.processors.speaker import PostProcessingForDiarizedAsr + +# diarized transcript requires word_timestamps from ASR +whisper_params.word_timestamps = True +transcription = await self.asr_handle.transcribe( + audio=audio, params=whisper_params +) +diarized_output = await self.diar_handle.diarize( + audio=audio, params=diar_params +) +updated_segments = PostProcessingForDiarizedAsr( + diarized_segments=diarized_output["segments"], + transcription_segments=transcription["segments"], +) +output_segments = [ + s.model_dump(include=["text", "time_interval", "speaker"]) + for s in updated_segments +] +``` +An example notebook on diarized transcription is available at [notebooks/diarized_transcription_example.ipynb](https://github.com/mobiusml/aana_sdk/tree/main/notebooks/diarized_transcription_example.ipynb). -Example configuration for the PyannoteSpeakerDiarization model is available at [Speaker Diarization](./speaker_recognition.md) model hub. - -You can simply define the model deployments and the endpoint to transcribe the video with diarization. 
Below code snippet shows the custom endpoint class `TranscribeVideoWithDiarEndpoint` to combine the outputs from ASR and diarization deployments: - - ```python - from aana.api.api_generation import Endpoint - from aana.core.models.speaker import PyannoteSpeakerDiarizationParams - from aana.core.models.video import VideoInput - from aana.core.models.whisper import WhisperParams - from aana.deployments.whisper_deployment import WhisperOutput - - from aana.deployments.aana_deployment_handle import AanaDeploymentHandle - - from aana.integrations.external.yt_dlp import download_video - from aana.processors.remote import run_remote - from aana.processors.speaker import ASRPostProcessingForDiarization - from aana.processors.video import extract_audio - - class TranscribeVideoWithDiarEndpoint(Endpoint): - """Transcribe video with diarization endpoint.""" - - async def initialize(self): - """Initialize the endpoint.""" - self.asr_handle = await AanaDeploymentHandle.create("asr_deployment") - self.diar_handle = await AanaDeploymentHandle.create("diarization_deployment") - await super().initialize() - - async def run( - self, - video: VideoInput, - whisper_params: WhisperParams, - diar_params: PyannoteSpeakerDiarizationParams, - ) -> WhisperOutput: - """Transcribe video with diarization.""" - video_obj = await run_remote(download_video)(video_input=video) - audio = extract_audio(video=video_obj) - - # diarized transcript requires word_timestamps from ASR - whisper_params.word_timestamps = True - transcription = await self.asr_handle.transcribe( - audio=audio, params=whisper_params - ) - diarized_output = await self.diar_handle.diarize( - audio=audio, params=diar_params - ) - post_processor = ASRPostProcessingForDiarization( - diarized_segments=diarized_output["segments"], - transcription_segments=transcription["segments"], - ) - updated_segments = post_processor.process() - output_segments = [ - s.model_dump(include=["text", "time_interval", "speaker"]) - for s in updated_segments - ] - - return {"segments": output_segments} - ``` -An example notebook on diarized transcription is available at `notebooks/diarized_transcription_example.ipynb`. diff --git a/docs/pages/model_hub/speaker_recognition.md b/docs/pages/model_hub/speaker_recognition.md index 5aa053e4..fd2a4d7d 100644 --- a/docs/pages/model_hub/speaker_recognition.md +++ b/docs/pages/model_hub/speaker_recognition.md @@ -35,7 +35,7 @@ The PyAnnote speaker diarization models are gated, requiring special access. To To get your Hugging Face access token, visit the [Hugging Face Settings - Tokens](https://huggingface.co/settings/tokens). -### Example Configurations +## Example Configurations As an example, let's see how to configure the Pyannote Speaker Diarization deployment for the [Speaker Diarization-3.1 model](https://huggingface.co/pyannote/speaker-diarization-3.1). @@ -55,7 +55,7 @@ As an example, let's see how to configure the Pyannote Speaker Diarization deplo ) ``` -### Diarized ASR +## Diarized ASR -Speaker Diarization output can be combined with ASR to generate transcription with speaker information. Further details and code snippet are available in [ASR model hub](./asr.md). +Speaker Diarization output can be combined with ASR to generate transcription with speaker information. Further details and code snippet are available in [ASR model hub](./asr.md/#diarized-asr). 
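The overlap-based speaker selection implemented in `get_speaker_for_interval` earlier in this patch can be illustrated in isolation. The sketch below is a simplified, self-contained variant rather than the SDK's code: plain dataclasses stand in for `TimeInterval` and `SpeakerDiarizationSegment`, and the `pick_speaker` name and its tie-breaking rule are assumptions for illustration. It picks the diarization segment with the largest overlap and, with `fill_nearest=True`, falls back to the closest segment when nothing overlaps.

```python
from dataclasses import dataclass


@dataclass
class Interval:
    """Minimal stand-in for aana's TimeInterval model."""

    start: float
    end: float


@dataclass
class DiarSegment:
    """Minimal stand-in for aana's SpeakerDiarizationSegment model."""

    speaker: str
    time_interval: Interval


def pick_speaker(sd_segments, start, end, fill_nearest=False):
    """Return the speaker whose segment overlaps [start, end] the most."""
    candidates = []
    for seg in sd_segments:
        s, e = seg.time_interval.start, seg.time_interval.end
        overlap = max(0.0, min(end, e) - max(start, s))  # overlap duration
        distance = min(abs(start - e), abs(end - s))  # gap to the segment
        if overlap > 0 or fill_nearest:
            candidates.append((overlap, -distance, seg.speaker))
    if not candidates:
        return None
    # Largest overlap wins; ties are broken by the smallest distance.
    return max(candidates)[2]


diarization = [
    DiarSegment("SPEAKER_00", Interval(0.0, 5.0)),
    DiarSegment("SPEAKER_01", Interval(5.0, 9.0)),
]
assert pick_speaker(diarization, 4.8, 5.6) == "SPEAKER_01"  # 0.6 s vs 0.2 s overlap
assert pick_speaker(diarization, 9.5, 10.0) is None  # no overlap, no fallback
assert pick_speaker(diarization, 9.5, 10.0, fill_nearest=True) == "SPEAKER_01"
```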
diff --git a/docs/reference/processors.md b/docs/reference/processors.md index a656a6bb..0f456f68 100644 --- a/docs/reference/processors.md +++ b/docs/reference/processors.md @@ -3,4 +3,4 @@ ::: aana.processors.remote ::: aana.processors.video ::: aana.processors.batch -::: aana.processors.speaker \ No newline at end of file +::: aana.processors.speaker.PostProcessingForDiarizedAsr \ No newline at end of file diff --git a/notebooks/diarized_transcription_example.ipynb b/notebooks/diarized_transcription_example.ipynb index 5d0114eb..df7a9812 100644 --- a/notebooks/diarized_transcription_example.ipynb +++ b/notebooks/diarized_transcription_example.ipynb @@ -37,15 +37,13 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "2024-09-23 11:22:15,770\tWARNING services.py:2017 -- WARNING: The object store is using /tmp instead of /dev/shm because /dev/shm has only 67108864 bytes available. This will harm performance! You may be able to free up space by deleting files in /dev/shm. If you are inside a Docker container, you can increase /dev/shm size by passing '--shm-size=10.24gb' to 'docker run' (or add it to the run_options list in a Ray cluster config). Make sure to set this to more than 30% of available RAM.\n", - "2024-09-23 11:22:16,932\tINFO worker.py:1774 -- Started a local Ray instance. View the dashboard at \u001b[1m\u001b[32m127.0.0.1:8265 \u001b[39m\u001b[22m\n", "INFO [alembic.runtime.migration] Context impl SQLiteImpl.\n", "INFO [alembic.runtime.migration] Will assume non-transactional DDL.\n" ] @@ -70,7 +68,7 @@ ")\n", "from aana.integrations.external.yt_dlp import download_video\n", "from aana.processors.remote import run_remote\n", - "from aana.processors.speaker import ASRPostProcessingForDiarization\n", + "from aana.processors.speaker import PostProcessingForDiarizedAsr\n", "from aana.processors.video import extract_audio\n", "from aana.sdk import AanaSDK\n", "\n", @@ -128,11 +126,10 @@ " diarized_output = await self.diar_handle.diarize(\n", " audio=audio, params=diar_params\n", " )\n", - " post_processor = ASRPostProcessingForDiarization(\n", + " updated_segments = PostProcessingForDiarizedAsr.process(\n", " diarized_segments=diarized_output[\"segments\"],\n", " transcription_segments=transcription[\"segments\"],\n", " )\n", - " updated_segments = post_processor.process()\n", " output_segments = [\n", " s.model_dump(include=[\"text\", \"time_interval\", \"speaker\"])\n", " for s in updated_segments\n", @@ -173,7 +170,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 8, "metadata": {}, "outputs": [ { @@ -196,7 +193,7 @@ "\n" ], "text/plain": [ - "Documentation is available at \u001b]8;id=363073;http://127.0.0.1:8000/docs\u001b\\\u001b[4;94mhttp://127.0.0.1:8000/docs\u001b[0m\u001b]8;;\u001b\\ and \u001b]8;id=935227;http://127.0.0.1:8000/redoc\u001b\\\u001b[4;94mhttp://127.0.0.1:8000/redoc\u001b[0m\u001b]8;;\u001b\\\n" + "Documentation is available at \u001b]8;id=619085;http://127.0.0.1:8000/docs\u001b\\\u001b[4;94mhttp://127.0.0.1:8000/docs\u001b[0m\u001b]8;;\u001b\\ and \u001b]8;id=66065;http://127.0.0.1:8000/redoc\u001b\\\u001b[4;94mhttp://127.0.0.1:8000/redoc\u001b[0m\u001b]8;;\u001b\\\n" ] }, "metadata": {}, @@ -216,7 +213,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 9, "metadata": {}, "outputs": [ { From 6c3cc55d4b8afe59e1e876035823b49f4edf3682 Mon Sep 17 00:00:00 2001 From: jiltseb Date: Wed, 25 Sep 2024 
11:24:33 +0000 Subject: [PATCH 16/18] doc changes and minor fixes --- ...pyannote_speaker_diarization_deployment.py | 4 +- aana/processors/speaker.py | 9 ++-- docs/pages/model_hub/asr.md | 49 +++++++++++++++++-- 3 files changed, 51 insertions(+), 11 deletions(-) diff --git a/aana/deployments/pyannote_speaker_diarization_deployment.py b/aana/deployments/pyannote_speaker_diarization_deployment.py index 17efc322..5e494e82 100644 --- a/aana/deployments/pyannote_speaker_diarization_deployment.py +++ b/aana/deployments/pyannote_speaker_diarization_deployment.py @@ -79,7 +79,9 @@ async def apply_config(self, config: dict[str, Any]): self.diarize_model.to(torch.device(self.device)) except Exception as e: - raise GatedRepoError from e + raise GatedRepoError( + message=f"This repository is private and requires a token to accept user conditions and access models in {self.model_id} pipeline." + ) from e async def __inference( self, audio: Audio, params: PyannoteSpeakerDiarizationParams diff --git a/aana/processors/speaker.py b/aana/processors/speaker.py index e2f7276e..cc045ce1 100644 --- a/aana/processors/speaker.py +++ b/aana/processors/speaker.py @@ -563,14 +563,13 @@ def _combine_homeogeneous_speaker_asr_segments( return merged_segments -# speaker diarization model occationally produce overlapping chunks/ same speaker segments, -# below function combines them properly - - def combine_homogeneous_speaker_diarization_segments( diarized_segments: list[SpeakerDiarizationSegment], ) -> list[SpeakerDiarizationSegment]: - """Combines segments with the same speaker into homogeneous speaker segments, ensuring no overlapping times. + """Combines same speaker segments in a diarization output. + + Speaker diarization model occationally produce overlapping chunks or segments belonging to the same speaker. + This method combines segments with the same speaker into homogeneous speaker segments, ensuring no overlapping times. Args: diarized_segments (list(SpeakerDiarizationSegment)): Input with segments that may have overlapping times. diff --git a/docs/pages/model_hub/asr.md b/docs/pages/model_hub/asr.md index 7146bfbf..94e0f4b2 100644 --- a/docs/pages/model_hub/asr.md +++ b/docs/pages/model_hub/asr.md @@ -63,6 +63,9 @@ You can simply define the model deployments and the endpoint to transcribe the v ```python from aana.processors.speaker import PostProcessingForDiarizedAsr +# 1. create ASR and Speaker Diarization deployments +# 2. Initilaize the endpoint with self.asr_handle and self.diar_handle + # diarized transcript requires word_timestamps from ASR whisper_params.word_timestamps = True transcription = await self.asr_handle.transcribe( @@ -71,14 +74,50 @@ transcription = await self.asr_handle.transcribe( diarized_output = await self.diar_handle.diarize( audio=audio, params=diar_params ) -updated_segments = PostProcessingForDiarizedAsr( +updated_segments = PostProcessingForDiarizedAsr.process( diarized_segments=diarized_output["segments"], transcription_segments=transcription["segments"], ) -output_segments = [ - s.model_dump(include=["text", "time_interval", "speaker"]) - for s in updated_segments -] + +# updated_segments will have speaker information as well: + +# "segments": [ +# { +# "text": " Hello. 
Hello.", +# "time_interval": { +# "start": 6.38, +# "end": 7.84 +# }, +# "confidence": 0.8329984157521475, +# "no_speech_confidence": 0.012033582665026188, +# "words": [ +# { +# "word": " Hello.", +# "speaker": "SPEAKER_01", +# "time_interval": { +# "start": 6.38, +# "end": 7.0 +# }, +# "alignment_confidence": 0.6853185296058655 +# }, +# ... +# ], +# "speaker": "SPEAKER_01" +# }, +# { +# "text": " Oh, hello. I didn't know you were there.", +# "time_interval": { +# "start": 8.3, +# "end": 9.68 +# }, +# "confidence": 0.8329984157521475, +# "no_speech_confidence": 0.012033582665026188, +# "words": [... +# ], +# "speaker": "SPEAKER_02" +# }, +# ...] + ``` An example notebook on diarized transcription is available at [notebooks/diarized_transcription_example.ipynb](https://github.com/mobiusml/aana_sdk/tree/main/notebooks/diarized_transcription_example.ipynb). From 8682d74a0c9b4565f0a436e4bc0337c95141745e Mon Sep 17 00:00:00 2001 From: jiltseb Date: Wed, 25 Sep 2024 14:36:40 +0000 Subject: [PATCH 17/18] minor doc edits --- docs/pages/model_hub/asr.md | 4 +++- docs/pages/model_hub/speaker_recognition.md | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/docs/pages/model_hub/asr.md b/docs/pages/model_hub/asr.md index 94e0f4b2..0d67a7f6 100644 --- a/docs/pages/model_hub/asr.md +++ b/docs/pages/model_hub/asr.md @@ -62,6 +62,7 @@ You can simply define the model deployments and the endpoint to transcribe the v ```python from aana.processors.speaker import PostProcessingForDiarizedAsr +from aana.core.models.base import pydantic_to_dict # 1. create ASR and Speaker Diarization deployments # 2. Initilaize the endpoint with self.asr_handle and self.diar_handle @@ -78,8 +79,9 @@ updated_segments = PostProcessingForDiarizedAsr.process( diarized_segments=diarized_output["segments"], transcription_segments=transcription["segments"], ) +output = pydantic_to_dict(updated_segments) -# updated_segments will have speaker information as well: +# output will have speaker information as well: # "segments": [ # { diff --git a/docs/pages/model_hub/speaker_recognition.md b/docs/pages/model_hub/speaker_recognition.md index fd2a4d7d..bcd0bfc3 100644 --- a/docs/pages/model_hub/speaker_recognition.md +++ b/docs/pages/model_hub/speaker_recognition.md @@ -19,7 +19,7 @@ The PyAnnote speaker diarization models are gated, requiring special access. To use these models: 1. **Request Access**: - Visit the [PyAnnote Speaker Diarization 3.1 model page](https://huggingface.co/pyannote/speaker-diarization-3.1) on Hugging Face. Log in, fil out the form, and request access. + Visit the [PyAnnote Speaker Diarization 3.1 model page](https://huggingface.co/pyannote/speaker-diarization-3.1) and [Pyannote Speaker Segmentation 3.0 model page](https://huggingface.co/pyannote/segmentation-3.0) on Hugging Face. Log in, fil out the forms, and request access. 2. **Approval**: - If automatic, access is granted immediately. 
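The merging behaviour documented for `combine_homogeneous_speaker_diarization_segments` in this patch (overlapping chunks or consecutive segments from the same speaker are collapsed into one segment) can be sketched with plain tuples. This is an illustrative approximation rather than the SDK's implementation: the `merge_same_speaker` name, the tuple representation, and the treatment of gaps between same-speaker chunks are assumptions.

```python
def merge_same_speaker(segments):
    """Merge consecutive (start, end, speaker) tuples that share a speaker.

    Overlapping boundaries collapse into a single interval, so the result
    contains one segment per run of the same speaker with no overlaps.
    """
    merged = []
    for start, end, speaker in sorted(segments):
        if merged and merged[-1][2] == speaker:
            prev_start, prev_end, _ = merged[-1]
            merged[-1] = (prev_start, max(prev_end, end), speaker)
        else:
            merged.append((start, end, speaker))
    return merged


raw = [
    (0.0, 4.2, "SPEAKER_00"),
    (3.9, 6.0, "SPEAKER_00"),  # overlaps the previous chunk
    (6.0, 8.5, "SPEAKER_01"),
    (8.8, 9.4, "SPEAKER_01"),  # same speaker again after a short gap
]
assert merge_same_speaker(raw) == [
    (0.0, 6.0, "SPEAKER_00"),
    (6.0, 9.4, "SPEAKER_01"),
]
```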
From 35364bb30d500eee0f828495ba2159369898fe48 Mon Sep 17 00:00:00 2001 From: jiltseb Date: Thu, 26 Sep 2024 07:39:07 +0000 Subject: [PATCH 18/18] added example output, comments --- docs/pages/model_hub/asr.md | 67 ++++++++++++++----------------------- 1 file changed, 25 insertions(+), 42 deletions(-) diff --git a/docs/pages/model_hub/asr.md b/docs/pages/model_hub/asr.md index 0d67a7f6..63e88218 100644 --- a/docs/pages/model_hub/asr.md +++ b/docs/pages/model_hub/asr.md @@ -64,61 +64,44 @@ You can simply define the model deployments and the endpoint to transcribe the v from aana.processors.speaker import PostProcessingForDiarizedAsr from aana.core.models.base import pydantic_to_dict -# 1. create ASR and Speaker Diarization deployments -# 2. Initilaize the endpoint with self.asr_handle and self.diar_handle # diarized transcript requires word_timestamps from ASR whisper_params.word_timestamps = True + +# asr_handle is an AanaDeploymentHandle for WhisperDeployment transcription = await self.asr_handle.transcribe( audio=audio, params=whisper_params ) + +# diar_handle is an AanaDeploymentHandle for PyannoteSpeakerDiarizationDeployment diarized_output = await self.diar_handle.diarize( audio=audio, params=diar_params ) + updated_segments = PostProcessingForDiarizedAsr.process( diarized_segments=diarized_output["segments"], transcription_segments=transcription["segments"], ) -output = pydantic_to_dict(updated_segments) - -# output will have speaker information as well: - -# "segments": [ -# { -# "text": " Hello. Hello.", -# "time_interval": { -# "start": 6.38, -# "end": 7.84 -# }, -# "confidence": 0.8329984157521475, -# "no_speech_confidence": 0.012033582665026188, -# "words": [ -# { -# "word": " Hello.", -# "speaker": "SPEAKER_01", -# "time_interval": { -# "start": 6.38, -# "end": 7.0 -# }, -# "alignment_confidence": 0.6853185296058655 -# }, -# ... -# ], -# "speaker": "SPEAKER_01" -# }, -# { -# "text": " Oh, hello. I didn't know you were there.", -# "time_interval": { -# "start": 8.3, -# "end": 9.68 -# }, -# "confidence": 0.8329984157521475, -# "no_speech_confidence": 0.012033582665026188, -# "words": [... -# ], -# "speaker": "SPEAKER_02" -# }, -# ...] + +# updated_segments will have speaker information as well: + +# [AsrSegment(text=' Hello. Hello.', +# time_interval=TimeInterval(start=6.38, end=7.84), +# confidence=0.8329984157521475, +# no_speech_confidence=0.012033582665026188, +# words=[AsrWord(word=' Hello.', speaker='SPEAKER_01',time_interval=TimeInterval(start=6.38, end=7.0), alignment_confidence=0.6853185296058655), +# AsrWord(word=' Hello.', speaker='SPEAKER_01', time_interval=TimeInterval(start=7.5, end=7.84), alignment_confidence=0.7124693989753723)], +# speaker='SPEAKER_01'), +# +# AsrSegment(text=" Oh, hello. I didn't know you were there.", +# time_interval=TimeInterval(start=8.3, end=9.68), +# confidence=0.8329984157521475, +# no_speech_confidence=0.012033582665026188, +# words=[AsrWord(word=' Oh,', speaker='SPEAKER_02', time_interval=TimeInterval(start=8.3, end=8.48), alignment_confidence=0.8500092029571533), +# AsrWord(word=' hello.', speaker='SPEAKER_02', time_interval=TimeInterval(start=8.5, end=8.76), alignment_confidence=0.9408962726593018), ...], +# speaker='SPEAKER_02'), +# ... +# ] ``` An example notebook on diarized transcription is available at [notebooks/diarized_transcription_example.ipynb](https://github.com/mobiusml/aana_sdk/tree/main/notebooks/diarized_transcription_example.ipynb).
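As a usage note on the example output above: once `PostProcessingForDiarizedAsr.process` has attached speaker labels, the segments can be folded into a readable transcript with a few lines of Python. The sketch below relies only on the `text`, `time_interval`, and `speaker` fields shown in the example; the `to_transcript` helper, the `updated_segments` variable, and the grouping rule are assumptions for illustration.

```python
def to_transcript(segments):
    """Render speaker-labelled ASR segments as 'SPEAKER [start-end]: text' lines.

    Consecutive segments from the same speaker are grouped into one line.
    """
    grouped = []
    for seg in segments:
        speaker = seg.speaker or "UNKNOWN"
        start, end = seg.time_interval.start, seg.time_interval.end
        text = seg.text.strip()
        if grouped and grouped[-1][0] == speaker:
            prev_speaker, prev_start, _, prev_text = grouped[-1]
            grouped[-1] = (prev_speaker, prev_start, end, f"{prev_text} {text}")
        else:
            grouped.append((speaker, start, end, text))
    return "\n".join(
        f"{speaker} [{start:.2f}-{end:.2f}]: {text}"
        for speaker, start, end, text in grouped
    )


# print(to_transcript(updated_segments)) would produce something like:
# SPEAKER_01 [6.38-7.84]: Hello. Hello.
# SPEAKER_02 [8.30-9.68]: Oh, hello. I didn't know you were there.
```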