Extend audio ds_tool #113

Open
wants to merge 12 commits into base: main
118 changes: 111 additions & 7 deletions ultravox/tools/ds_tool/ds_tool.py
@@ -6,6 +6,7 @@

import datasets
import jinja2
import numpy as np
import openai
import simple_parsing
import yaml
@@ -175,6 +176,103 @@ def _map_sample(self, sample, exclude_fields):
        return sample


@dataclasses.dataclass
class AudioExtensionTask:
    audio_column_name: str = simple_parsing.field(default="audio", alias="-a")
    text_column_name: str = simple_parsing.field(default="sentence", alias="-A")
Contributor suggested change:
    text_column_name: str = simple_parsing.field(default="sentence", alias="-A")
    text_column_name: str = simple_parsing.field(default="sentence", alias="-t")

    translation_column_name: Optional[str] = simple_parsing.field(
        default="translation", alias="-T"
    )
    id_column_name: str = simple_parsing.field(default="id", alias="-i")
    extend_type: str = simple_parsing.field(
default="repeat", alias="-o", choices=["repeat", "combine"]
)
multiplier: int = simple_parsing.field(default=2, alias="-m")

    def map_split(
        self,
        ds_split: datasets.Dataset,
        num_proc: int,
        writer_batch_size: int,
        exclude_fields: List[str],
    ) -> datasets.Dataset:
        print(
            f'Extending audio using "{self.extend_type}" method with multiplier {self.multiplier}'
        )

        if self.extend_type == "repeat":
            return ds_split.map(
                function=self._map_sample_repeat,
                num_proc=num_proc,
                writer_batch_size=writer_batch_size,
                remove_columns=ds_split.column_names,
            )
        elif self.extend_type == "combine":
            return ds_split.map(
                function=self._map_batch_combine,
                batched=True,
                batch_size=self.multiplier,
                num_proc=num_proc,
                writer_batch_size=writer_batch_size,
                remove_columns=ds_split.column_names,
            )
        else:
            raise ValueError(f"Unknown extend_type: {self.extend_type}")

    def _map_sample_repeat(self, sample):
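        """Repeat one sample: tile its audio array and repeat its text `multiplier` times."""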
        audio = sample[self.audio_column_name]

        new_audio = {
            "sampling_rate": audio["sampling_rate"],
            "array": np.tile(audio["array"], self.multiplier),
        }

        new_sample = {
            self.audio_column_name: new_audio,
            self.text_column_name: " ".join(
                [sample[self.text_column_name]] * self.multiplier
            ),
            self.id_column_name: sample[self.id_column_name],
        }

        if self.translation_column_name is not None:
            translation = sample.get(self.translation_column_name)
            if translation is not None:
                new_sample[self.translation_column_name] = " ".join(
                    [translation] * self.multiplier
                )

        return new_sample

    def _map_batch_combine(self, batch):
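        """Combine a batch of up to `multiplier` samples into one by concatenating audio and joining text."""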
        audios = batch[self.audio_column_name]
        sentences = batch[self.text_column_name]
        ids = batch[self.id_column_name]

        combined_audio = {
            "sampling_rate": audios[0]["sampling_rate"],
            "array": np.concatenate([audio["array"] for audio in audios]),
        }
        combined_sentences = " ".join(sentences)
        combined_ids = "+".join(ids)

        new_batch = {
            self.audio_column_name: [combined_audio],
            self.text_column_name: [combined_sentences],
            self.id_column_name: [combined_ids],
        }

        if self.translation_column_name in batch:
            translations = batch[self.translation_column_name]
            if translations is not None and all(
                translation is not None for translation in translations
            ):
                combined_translations = " ".join(translations)
                new_batch[self.translation_column_name] = [combined_translations]

        return new_batch


# This script is used to either generate audio samples from text using a TTS model, or to generate text samples using a text generation model.
# just ds_tool tts -d google/boolq -u fixie-ai/boolq-audio -T {{question}} -a audio --token $HF_WRITE_TOKEN
# just ds_tool textgen -d fixie-ai/boolq-audio -u fixie-ai/bar -T {{explanation}} -b https://api.fireworks.ai/inference/v1 -k $FIREWORKS_API_KEY -m accounts/fireworks/models/llama-v3-8b-instruct
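# A possible invocation of the new audioext task might look like the following (dataset and upload
# names are placeholders; the flags map to the AudioExtensionTask fields above):
# just ds_tool audioext -d <source-dataset> -u <upload-name> -a audio -o combine -m 2 --token $HF_WRITE_TOKEN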
@@ -218,10 +316,12 @@ class DatasetToolArgs:
        default_factory=lambda: ["audio"]
    )

    task: Union[TtsTask, TextGenerationTask] = simple_parsing.subgroups(
        {"tts": TtsTask, "textgen": TextGenerationTask}, # type: ignore
        default_factory=TtsTask,
        positional=True,
    task: Union[TtsTask, TextGenerationTask, AudioExtensionTask] = (
        simple_parsing.subgroups(
            {"tts": TtsTask, "textgen": TextGenerationTask, "audioext": AudioExtensionTask}, # type: ignore
            default_factory=TtsTask,
            positional=True,
        )
    )

    def __post_init__(self):
@@ -308,9 +408,9 @@ def process_and_upload_split_rescursive(
                    self.chunks_not_uploaded.append((start_index, end_index))
                    return None
                failed_chunk_ranges.append((chunk_start, chunk_end))
        successful_chunks = self.args.num_chunks - len(failed_chunk_ranges)
        successful_chunks = total_chunks - len(failed_chunk_ranges)
        print(
            f"Finished processing and uploading {successful_chunks}/{self.args.num_chunks} chunks for range [{start_index}, {end_index})"
            f"Finished processing and uploading {successful_chunks}/{total_chunks} chunks for range [{start_index}, {end_index})"
        )
        if len(failed_chunk_ranges) > 0:
            for start, end in failed_chunk_ranges:
@@ -358,7 +458,11 @@ def _upload(self, ds_chunk_processed: datasets.Dataset, data_dir: str, split_nam
"split": split_name,
}
assert isinstance(self.args.upload_name, str)
ds_split_chunked.push_to_hub(self.args.upload_name, **hub_args)
try:
ds_split_chunked.push_to_hub(self.args.upload_name, **hub_args)
except Exception as e:
print(f"Failed to upload chunk to hub: {e}")
raise e


def main(args: DatasetToolArgs):