Add chunking to ds_tool #97

Merged: 31 commits, Sep 6, 2024
The diff below shows changes from 9 of the 31 commits.

Commits
81bb206
Added chunking
liPatrick Aug 23, 2024
3ba9ff9
Dynamic chunking
liPatrick Aug 23, 2024
2ce70f3
Raise error in jinja template
liPatrick Aug 23, 2024
168275b
Format errors
liPatrick Aug 23, 2024
daa2ce9
Fix test
liPatrick Aug 23, 2024
31a8a13
Return sample when format is wrong
liPatrick Aug 23, 2024
b1f6d67
Remove template failures counter, doesn't work with multi-proc
liPatrick Aug 26, 2024
4163628
Addressing comments
liPatrick Aug 26, 2024
255c46d
Handle text proc asr error in dataset.py
liPatrick Aug 26, 2024
f1236ce
removing extra prints
liPatrick Aug 26, 2024
eab6b80
Make process upload split recursive
liPatrick Aug 26, 2024
4e42d58
Add more comments
liPatrick Aug 27, 2024
93c25bf
More comments
liPatrick Aug 27, 2024
1479d91
Use None instead of empty quotes. Type issue resolved
liPatrick Aug 27, 2024
e5ae10f
Chunked dataset subclass
liPatrick Aug 27, 2024
dce349e
Merge branch 'main' into patrick/chunking-ds_tool
liPatrick Aug 29, 2024
ded2f29
HF readme integration
liPatrick Sep 4, 2024
b7b2e80
format
liPatrick Sep 4, 2024
5f70e0f
Add dataset version to load
liPatrick Sep 4, 2024
653e3b0
Remove empty audio
liPatrick Sep 4, 2024
9c66031
Moved asr error try catch to get_transcribe_sample
liPatrick Sep 5, 2024
230f962
Remove total samples processed
liPatrick Sep 5, 2024
ac9bc03
Change continuation to text
liPatrick Sep 5, 2024
0adf5a6
Some fixes
liPatrick Sep 6, 2024
6ebadb0
Fix more bugs
liPatrick Sep 6, 2024
0d08712
Address comments
liPatrick Sep 6, 2024
d358768
Address comments
liPatrick Sep 6, 2024
3f8a0db
Fix import format
liPatrick Sep 6, 2024
f60b90e
Extra filter method
liPatrick Sep 6, 2024
cdc6dee
Check empty columns filter
liPatrick Sep 6, 2024
bddf33c
Add empty column check
liPatrick Sep 6, 2024
21 changes: 15 additions & 6 deletions ultravox/data/datasets.py
@@ -449,8 +449,11 @@ def __init__(self, args: VoiceDatasetArgs) -> None:
)
self._init_dataset(dataset)

def _get_sample(self, row: transformers.BatchFeature) -> VoiceSample:
return self._get_transcribe_sample(row, tproc=text_proc.format_asr_text)
def _get_sample(self, row: transformers.BatchFeature) -> Optional[VoiceSample]:
try:
return self._get_transcribe_sample(row, tproc=text_proc.format_asr_text)
except text_proc.FormatASRError:
return None


class EmptyDataset(data.IterableDataset):
@@ -776,8 +779,11 @@ def __init__(self, args: VoiceDatasetArgs) -> None:
ds = ds.shuffle(seed=self._args.shuffle_seed)
self._init_dataset(ds)

def _get_sample(self, row: transformers.BatchFeature) -> VoiceSample:
return self._get_transcribe_sample(row, tproc=text_proc.format_asr_text)
def _get_sample(self, row: transformers.BatchFeature) -> Optional[VoiceSample]:
try:
return self._get_transcribe_sample(row, tproc=text_proc.format_asr_text)
except text_proc.FormatASRError:
return None


# TODO: this dataset can be replaced with GenericVoiceDataset and will be removed/updated in the future.
@@ -796,8 +802,11 @@ def __init__(self, args: VoiceDatasetArgs) -> None:
)
self._init_dataset(dataset)

def _get_sample(self, row) -> VoiceSample:
return self._get_transcribe_sample(row, tproc=text_proc.format_asr_text)
def _get_sample(self, row) -> Optional[VoiceSample]:
try:
return self._get_transcribe_sample(row, tproc=text_proc.format_asr_text)
except text_proc.FormatASRError:
return None


# TODO: this dataset can be replaced with GenericVoiceDataset and will be removed/updated in the future.
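
These `_get_sample` overrides now return `None` instead of propagating a failure when a transcript is unusable. As a point of reference, here is a minimal, illustrative sketch (not part of this diff) of how a consuming iterator can skip such rows; the PR's actual handling lives in the shared dataset base class.

```python
# Illustrative only: skip rows whose _get_sample returned None after a FormatASRError.
from typing import Callable, Iterable, Iterator, Optional, TypeVar

Row = TypeVar("Row")
Sample = TypeVar("Sample")

def iter_valid_samples(
    rows: Iterable[Row], get_sample: Callable[[Row], Optional[Sample]]
) -> Iterator[Sample]:
    for row in rows:
        sample = get_sample(row)
        if sample is None:
            continue  # garbage utterance or empty transcript; drop the row
        yield sample
```
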
13 changes: 10 additions & 3 deletions ultravox/data/text_proc.py
@@ -7,6 +7,11 @@
import nltk # needed for truecase
import truecase


class FormatASRError(ValueError):
pass


# only in master thread per node to avoid
# other threads overwriting the downloaded .zip
if int(os.environ.get("LOCAL_RANK", 0)) == 0:
@@ -31,15 +36,17 @@ def format_asr_text(text: str) -> str:
remaining_words = []
for word in text.split():
if word in GIGASPEECH_GARBAGE_UTTERANCE_TAGS:
return ""
raise FormatASRError(f"Garbage utterance tag found: {word}")
if word in GIGASPEECH_PUNCTUATIONS:
word = GIGASPEECH_PUNCTUATIONS[word]
remaining_words.append(word)

text = " ".join(remaining_words)
text = truecase.get_true_case(text)

return text.strip()
text_stripped = text.strip()
if len(text_stripped) == 0:
raise FormatASRError("Empty text after processing")
return text_stripped


CONVERSATIONAL_FILLER = [
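
`format_asr_text` now raises `FormatASRError` for garbage-utterance tags and for text that becomes empty after cleanup, instead of returning `""`. A minimal usage sketch (illustrative, not part of the PR) mirroring the try/except pattern used by the dataset and ds_tool callers:

```python
from ultravox.data import text_proc

def try_format(text: str):
    try:
        return text_proc.format_asr_text(text)
    except text_proc.FormatASRError as err:
        # Previously this case returned ""; callers now decide how to handle the failure.
        print(f"Dropping sample: {err}")
        return None

try_format("I SEE LOTS OF PEOPLE HAVE DRONES HERE <COMMA> MAVERICK AS WELL <PERIOD>")  # formatted text
try_format("<NOISE> OH WHAT WAS THAT?")  # None (garbage utterance tag)
```
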
9 changes: 6 additions & 3 deletions ultravox/data/text_proc_test.py
@@ -10,9 +10,12 @@
"I SEE LOTS OF PEOPLE HAVE DRONES HERE <COMMA> MAVERICK AS WELL <PERIOD> ",
"I see lots of people have drones here, maverick as well.",
),
# truecase messes with the case of special tags too, but we probably don't care about that
("<NOISE> OH WHAT WAS THAT?", ""),
],
)
def test_no_space_punctuation(text, expected):
def test_format_asr_text(text, expected):
assert text_proc.format_asr_text(text) == expected


def test_garbage_utterance():
with pytest.raises(text_proc.FormatASRError):
text_proc.format_asr_text("<NOISE> OH WHAT WAS THAT?")
7 changes: 1 addition & 6 deletions ultravox/tools/ds_tool/continuation.jinja
@@ -1,8 +1,3 @@
{% set formatted_text = text_proc.format_asr_text(text) %}
{% if formatted_text != "" %}
Continue the following text using less than 50 words:

{{ formatted_text }}
{% else %}
{% raise ValueError("The formatted text is empty.") %}
{% endif %}
{{ text_proc.format_asr_text(sentence) }}
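
The template no longer special-cases empty text: it simply calls `text_proc.format_asr_text(sentence)` and lets `FormatASRError` propagate. A condensed sketch of how ds_tool renders it (simplified from `_map_sample` below; the sample dict is made up):

```python
import jinja2
from ultravox.data import text_proc

template = "{{ text_proc.format_asr_text(sentence) }}"
sample = {"sentence": "I SEE LOTS OF PEOPLE HAVE DRONES HERE <COMMA> MAVERICK AS WELL <PERIOD>"}

prompt = jinja2.Template(template, undefined=jinja2.StrictUndefined).render(
    **sample, text_proc=text_proc
)
# If the sentence contains a garbage tag, FormatASRError escapes render() and is
# caught in TextGenerationTask._map_sample, which marks the sample for filtering.
```
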
206 changes: 152 additions & 54 deletions ultravox/tools/ds_tool/ds_tool.py
@@ -1,7 +1,9 @@
import dataclasses
import json
import math
import os
from typing import Any, Dict, List, Optional, Union
from collections import defaultdict
from typing import Any, Dict, List, Optional, Tuple, Union

import datasets
import jinja2
@@ -95,8 +97,6 @@ class TextGenerationTask:
max_tokens: int = 128
temperature: float = 0

template_failures: int = 0

def __post_init__(self):
# The OAI client is separate from the task to avoid pickling issues when multiprocessing.
global chat_client
@@ -116,18 +116,17 @@ def map_split(
writer_batch_size: int,
exclude_fields: List[str],
) -> datasets.Dataset:
print(f'Generating "{self.new_column_name}" with template:\n{self.template}')
# print(f'Generating "{self.new_column_name}" with template:\n{self.template}')
ds_mapped = ds_split.map(
lambda sample: self._map_sample(sample, set(exclude_fields)),
num_proc=num_proc,
writer_batch_size=writer_batch_size,
)
if self.template_failures == 0:
return ds_mapped
print("Finished generating text samples:", len(ds_mapped))

# Filter out samples where new_column_name is None
return ds_mapped.filter(
lambda sample: (True if sample[self.new_column_name] != None else False),
lambda sample: sample[self.new_column_name] != "",
num_proc=num_proc,
writer_batch_size=writer_batch_size,
)
@@ -143,20 +142,17 @@ def _map_sample(self, sample, exclude_fields):
rendered = jinja2.Template(
self.template, undefined=jinja2.StrictUndefined
).render(**filtered_sample, json_dump=json.dumps, text_proc=text_proc)

except text_proc.FormatASRError as e:
print(f"Format ASR Error {e}")
sample[self.new_column_name] = ""
return sample
except jinja2.TemplateError as e:
error_message = str(e)
if "The formatted text is empty." in error_message:
print("Formatted text is empty. Setting output to None.")
sample[self.new_column_name] = None
self.template_failures += 1
else:
print(f"Error rendering template: {e}")
print(f"template: {self.template}")
print(f"sample keys: {list(filtered_sample.keys())}")
raise ValueError(
f"Template rendering failed. Make sure all keys in the template exist in the sample."
) from e
print(f"Error rendering template: {e}")
print(f"template: {self.template}")
print(f"sample keys: {list(filtered_sample.keys())}")
raise ValueError(
f"Template rendering failed. Make sure all keys in the template exist in the sample."
) from e

if self.json_mode:
turns = yaml.safe_load(rendered)
@@ -189,7 +185,7 @@ def _map_sample(self, sample, exclude_fields):
class DatasetToolArgs:
# HF source dataset parameters
dataset_name: str = simple_parsing.field(alias="-d")
dataset_subset: Optional[str] = simple_parsing.field(default=None, alias="-S")
dataset_subset: Optional[str] = simple_parsing.field(None, alias="-S")
dataset_split: Optional[str] = simple_parsing.field(default=None, alias="-s")

# Local processing parameters
@@ -210,6 +206,10 @@ class DatasetToolArgs:
private: bool = simple_parsing.field(default=False)
token: Optional[str] = None

# Chunk processing parameters
max_chunk_split: int = simple_parsing.field(default=10)
chunk_split_threshold: int = simple_parsing.field(default=50000)

task: Union[TtsTask, TextGenerationTask] = simple_parsing.subgroups(
{"tts": TtsTask, "textgen": TextGenerationTask}, # type: ignore
default_factory=TtsTask,
@@ -223,6 +223,129 @@ def __post_init__(self):
self.upload_split = self.dataset_split


class DatasetChunkProcessor:
args: DatasetToolArgs
cache_dir: str = ".cache/ds_tool/processed_datasets"
chunks_not_uploaded: List[Tuple[int, int]] = []
total_samples_processed: Dict[str, int] = defaultdict(int)

def __init__(self, args: DatasetToolArgs):
self.args = args

def process_and_upload_split(self, split_name: str, ds_split: datasets.Dataset):
failed_chunk_ranges = self._split_chunk_and_upload(
split_name, ds_split, 0, len(ds_split)
)

while len(failed_chunk_ranges) > 0:
new_failed_ranges = []
for start, end in failed_chunk_ranges:
print(f"Retrying failed chunk range [{start}, {end})")
new_failed_ranges.extend(
self._split_chunk_and_upload(split_name, ds_split, start, end)
)
failed_chunk_ranges = new_failed_ranges
print(f"Could not upload chunks: {self.chunks_not_uploaded}")
print(
f"Finished processing and uploading all chunks for split {split_name}. Total samples processed: {self.total_samples_processed}"
)

def _split_chunk_and_upload(
self,
split_name: str,
ds_split: datasets.Dataset,
start_index: int,
end_index: int,
):
original_chunk_size = end_index - start_index
if original_chunk_size < self.args.chunk_split_threshold:
total_chunks = 1
chunk_size = original_chunk_size
print(
f"Chunk range [{start_index}, {end_index}) is too small to split further. Processing and uploading as a single chunk."
)
else:
total_chunks = self.args.max_chunk_split
chunk_size = math.ceil(original_chunk_size / total_chunks)
failed_chunk_ranges = []
print(
f"Processing and uploading {total_chunks} chunks for range [{start_index}, {end_index}) with chunk size {chunk_size}"
)
for chunk_start in range(start_index, end_index, chunk_size):
chunk_end = min(chunk_start + chunk_size, end_index)

ds_chunk = ds_split.select(range(chunk_start, chunk_end))
ds_chunk_name = f"chunk-range-{chunk_start:09d}-{chunk_end:09d}.parquet"
ds_chunk_hub_path = os.path.join(
str(self.args.upload_subset), split_name, ds_chunk_name
)
ds_chunk_cache_path = os.path.join(
self.cache_dir,
self.args.dataset_name.replace("/", "__"),
str(self.args.upload_subset),
split_name,
ds_chunk_name,
)
try:
if os.path.exists(ds_chunk_cache_path):
print(
f"Skipping chunk {ds_chunk_name} as it has already been processed and uploaded."
)
ds_chunk_processed = datasets.Dataset.from_parquet(
ds_chunk_cache_path
)
else:
print(f"Processing chunk {ds_chunk_name}")
ds_chunk_processed = self._process(ds_chunk)
print(
"Finished processing chunk with length", len(ds_chunk_processed)
)
if len(ds_chunk_processed) > 0:
self._upload(ds_chunk_processed, ds_chunk_hub_path, split_name)
ds_chunk_processed.to_parquet(ds_chunk_cache_path)
else:
print(f"Chunk {ds_chunk_name} has 0 samples. Not uploading.")
self.total_samples_processed[split_name] += len(ds_chunk_processed)

except Exception as e:
# If the error is unsupported operand type(s) for -=: 'NoneType' and 'float',
# then the huggingface README needs to be updated to have the
# download_size and dataset_size fields present under dataset_info (could be initialized to 0)
print(f"Failed to upload chunk {ds_chunk_name}: {e}. Retrying later.")
if total_chunks == 1:
print(
f"Finished processing and uploading 0/1 chunks for range [{start_index}, {end_index})"
)
self.chunks_not_uploaded.append((start_index, end_index))
return []
failed_chunk_ranges.append((chunk_start, chunk_end))
successful_chunks = self.args.max_chunk_split - len(failed_chunk_ranges)
print(
f"Finished processing and uploading {successful_chunks}/{self.args.max_chunk_split} chunks for range [{start_index}, {end_index})"
)
return failed_chunk_ranges

def _process(self, ds_chunk: datasets.Dataset) -> datasets.Dataset:
return self.args.task.map_split(
ds_chunk,
self.args.num_workers,
self.args.writer_batch_size,
self.args.exclude_fields,
)

def _upload(self, ds_chunk_processed: datasets.Dataset, data_dir: str, split_name):
print(f"Uploading chunk to hub: {data_dir}")
hub_args: Dict[str, Any] = {
"config_name": self.args.upload_subset,
"token": self.args.token or os.environ.get("HF_TOKEN"),
"private": self.args.private,
"data_dir": data_dir,
"num_shards": self.args.num_shards,
"split": split_name,
}
ds_chunk_processed.push_to_hub(self.args.upload_name, **hub_args)
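
To make the two new chunking knobs concrete: ranges smaller than `chunk_split_threshold` are processed and uploaded as a single chunk, larger ranges are split into up to `max_chunk_split` chunks, and any chunk that fails is re-split and retried. A small standalone sketch of the splitting arithmetic (this helper is not part of the PR; defaults copied from `DatasetToolArgs`):

```python
import math
from typing import List, Tuple

def plan_chunks(
    start: int, end: int, max_chunk_split: int = 10, chunk_split_threshold: int = 50000
) -> List[Tuple[int, int]]:
    """Restates the splitting logic of _split_chunk_and_upload, for illustration."""
    size = end - start
    if size < chunk_split_threshold:
        chunk_size = size  # too small to split further: process as one chunk
    else:
        chunk_size = math.ceil(size / max_chunk_split)
    return [(s, min(s + chunk_size, end)) for s in range(start, end, chunk_size)]

plan_chunks(0, 30_000)   # [(0, 30000)] -- below the threshold, single chunk
plan_chunks(0, 120_000)  # ten ranges of 12,000 rows; failed ranges get re-split on retry
```
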


def main(args: DatasetToolArgs):
ds_name = args.dataset_name
print(f'Loading dataset "{ds_name}" for task {args.task}')
@@ -236,47 +359,22 @@ def main(args: DatasetToolArgs):
if len(data_dict) > 1 and args.upload_split:
raise ValueError("Cannot upload multiple splits to a single split")

for split, ds_split in data_dict.items():
ds_chunk_proc = DatasetChunkProcessor(args)

for split_name, ds_split in data_dict.items():
print(
f"Processing dataset: {ds_name}, subset {args.dataset_subset}, split {args.dataset_split}, containing {len(ds_split)} samples"
f"Processing dataset: {ds_name}, subset {args.dataset_subset}, split {split_name}, containing {len(ds_split)} samples"
)
if args.shuffle:
ds_split = ds_split.shuffle(seed=args.shuffle_seed)
if args.num_samples:
ds_split = ds_split.select(range(args.num_samples))
data_dict[split] = args.task.map_split(
ds_split,
args.num_workers,
args.writer_batch_size,
args.exclude_fields,
)

hub_args: Dict[str, Any] = {
"config_name": args.upload_subset or "default",
"token": args.token or os.environ.get("HF_TOKEN"),
"revision": args.upload_branch,
"private": args.private,
}

if args.num_shards is not None:
hub_args["num_shards"] = args.num_shards
ds_chunk_proc.process_and_upload_split(split_name, ds_split)

try:
if args.dataset_split:
data_dict[args.dataset_split].push_to_hub(
args.upload_name, split=args.upload_split, **hub_args
)
else:
data_dict.push_to_hub(args.upload_name, **hub_args)
except Exception as e:
print(f"Failed to push to hub: {e}")

# If the push fails or upload_name is not specified, save the data locally.
for split in data_dict.keys():
output_name = f"{split}-00000-of-00001.parquet"
data_dict[split].to_parquet(output_name)
print(f"Saved to {output_name}")
print(f"Sample {0} of {split}: {data_dict[split][0]}")
# Note: After running this script, you need to manually update the README.md file with the new dataset information.
# 1. Change the configs section path to point to upload_subset/split_name/**
# 2. Update the dataset split num_examples field to the total number of samples processed. This can be found in the logs as the final output.


if __name__ == "__main__":
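
The note at the end of `ds_tool.py` asks for a manual README.md update after uploading. Purely as a hypothetical illustration of what those two steps touch (subset name, split, and counts are placeholders; the zero-initialized size fields relate to the upload caveat in `_split_chunk_and_upload`), the front matter might look like this, shown here as a YAML string inside Python:

```python
# Hypothetical README.md front matter; all values are placeholders.
import yaml

EXAMPLE_README_METADATA = yaml.safe_load(
    """
configs:
- config_name: my-subset
  data_files:
  - split: train
    path: my-subset/train/**
dataset_info:
  config_name: my-subset
  splits:
  - name: train
    num_examples: 123456  # total samples processed, from the final log line
  download_size: 0        # must be present (may start at 0) or chunk uploads can fail
  dataset_size: 0
"""
)
```
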