Add chunking to ds_tool #97

Merged Sep 6, 2024 (31 commits)

Changes from 6 commits

Commits
81bb206
Added chunking
liPatrick Aug 23, 2024
3ba9ff9
Dynamic chunking
liPatrick Aug 23, 2024
2ce70f3
Raise error in jinja template
liPatrick Aug 23, 2024
168275b
Format errors
liPatrick Aug 23, 2024
daa2ce9
Fix test
liPatrick Aug 23, 2024
31a8a13
Return sample when format is wrong
liPatrick Aug 23, 2024
b1f6d67
Remove template failures counter, doesn't work with multi-proc
liPatrick Aug 26, 2024
4163628
Addressing comments
liPatrick Aug 26, 2024
255c46d
Handle text proc asr error in dataset.py
liPatrick Aug 26, 2024
f1236ce
removing extra prints
liPatrick Aug 26, 2024
eab6b80
Make process upload split recursive
liPatrick Aug 26, 2024
4e42d58
Add more comments
liPatrick Aug 27, 2024
93c25bf
More comments
liPatrick Aug 27, 2024
1479d91
Use None instead of empty quotes. Type issue resolved
liPatrick Aug 27, 2024
e5ae10f
Chunked dataset subclass
liPatrick Aug 27, 2024
dce349e
Merge branch 'main' into patrick/chunking-ds_tool
liPatrick Aug 29, 2024
ded2f29
HF readme integration
liPatrick Sep 4, 2024
b7b2e80
format
liPatrick Sep 4, 2024
5f70e0f
Add dataset version to load
liPatrick Sep 4, 2024
653e3b0
Remove empty audio
liPatrick Sep 4, 2024
9c66031
Moved asr error try catch to get_transcribe_sample
liPatrick Sep 5, 2024
230f962
Remove total samples processed
liPatrick Sep 5, 2024
ac9bc03
Change continuation to text
liPatrick Sep 5, 2024
0adf5a6
Some fixes
liPatrick Sep 6, 2024
6ebadb0
Fix more bugs
liPatrick Sep 6, 2024
0d08712
Address comments
liPatrick Sep 6, 2024
d358768
Address comments
liPatrick Sep 6, 2024
3f8a0db
Fix import format
liPatrick Sep 6, 2024
f60b90e
Extra filter method
liPatrick Sep 6, 2024
cdc6dee
Check empty columns filter
liPatrick Sep 6, 2024
bddf33c
Add empty column check
liPatrick Sep 6, 2024
17 changes: 14 additions & 3 deletions ultravox/data/text_proc.py
@@ -7,6 +7,15 @@
import nltk # needed for truecase
import truecase


class GarbageUtteranceError(ValueError):
pass


class EmptyTranscriptError(ValueError):
pass


# only in master thread per node to avoid
# other threads overwriting the downloaded .zip
if int(os.environ.get("LOCAL_RANK", 0)) == 0:
@@ -31,15 +40,17 @@ def format_asr_text(text: str) -> str:
remaining_words = []
for word in text.split():
if word in GIGASPEECH_GARBAGE_UTTERANCE_TAGS:
return ""
raise GarbageUtteranceError(f"Garbage utterance tag found: {word}")
if word in GIGASPEECH_PUNCTUATIONS:
word = GIGASPEECH_PUNCTUATIONS[word]
remaining_words.append(word)

text = " ".join(remaining_words)
text = truecase.get_true_case(text)

return text.strip()
text_stripped = text.strip()
if len(text_stripped) == 0:
raise EmptyTranscriptError("Empty transcript after processing")
return text_stripped


CONVERSATIONAL_FILLER = [
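
A minimal sketch of how a caller might consume the new exception-based contract of format_asr_text (safe_format is a hypothetical helper, not part of this PR; the example strings come from the updated tests below):

from ultravox.data import text_proc

def safe_format(text: str):
    """Return the cleaned transcript, or None if the sample should be dropped."""
    try:
        return text_proc.format_asr_text(text)
    except text_proc.GarbageUtteranceError:
        # The transcript contains a garbage utterance tag such as <NOISE>.
        return None
    except text_proc.EmptyTranscriptError:
        # Nothing is left after tag substitution and whitespace stripping.
        return None

safe_format("I SEE LOTS OF PEOPLE HAVE DRONES HERE <COMMA> MAVERICK AS WELL <PERIOD> ")
# -> "I see lots of people have drones here, maverick as well."
safe_format("<NOISE> OH WHAT WAS THAT?")
# -> None (GarbageUtteranceError raised and swallowed)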
9 changes: 6 additions & 3 deletions ultravox/data/text_proc_test.py
@@ -10,9 +10,12 @@
"I SEE LOTS OF PEOPLE HAVE DRONES HERE <COMMA> MAVERICK AS WELL <PERIOD> ",
"I see lots of people have drones here, maverick as well.",
),
# truecase messes with the case of special tags too, but we probably don't care about that
("<NOISE> OH WHAT WAS THAT?", ""),
],
)
def test_no_space_punctuation(text, expected):
def test_format_asr_text(text, expected):
assert text_proc.format_asr_text(text) == expected


def test_garbage_utterance():
with pytest.raises(text_proc.GarbageUtteranceError):
text_proc.format_asr_text("<NOISE> OH WHAT WAS THAT?")
7 changes: 1 addition & 6 deletions ultravox/tools/ds_tool/continuation.jinja
@@ -1,8 +1,3 @@
{% set formatted_text = text_proc.format_asr_text(text) %}
{% if formatted_text != "" %}
Continue the following text using less than 50 words:

{{ formatted_text }}
{% else %}
{% raise ValueError("The formatted text is empty.") %}
{% endif %}
{{ text_proc.format_asr_text(text) }}
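
A rough sketch (not from the diff) of why the {% raise %} branch is no longer needed: exceptions raised by text_proc.format_asr_text inside the template propagate out of jinja2's render() unchanged, where _map_sample can catch them. The template string below assumes it mirrors the new continuation.jinja:

import jinja2
from ultravox.data import text_proc

template = (
    "Continue the following text using less than 50 words:\n"
    "\n"
    "{{ text_proc.format_asr_text(text) }}"
)

try:
    prompt = jinja2.Template(template, undefined=jinja2.StrictUndefined).render(
        text="<NOISE> OH WHAT WAS THAT?", text_proc=text_proc
    )
except (text_proc.GarbageUtteranceError, text_proc.EmptyTranscriptError):
    # ds_tool's _map_sample sets the new column to None for such samples
    # and map_split filters them out afterwards.
    prompt = None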
186 changes: 143 additions & 43 deletions ultravox/tools/ds_tool/ds_tool.py
@@ -1,7 +1,8 @@
import dataclasses
import json
import math
import os
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Optional, Tuple, Union

import datasets
import jinja2
@@ -116,7 +117,7 @@ def map_split(
writer_batch_size: int,
exclude_fields: List[str],
) -> datasets.Dataset:
print(f'Generating "{self.new_column_name}" with template:\n{self.template}')
# print(f'Generating "{self.new_column_name}" with template:\n{self.template}')
ds_mapped = ds_split.map(
lambda sample: self._map_sample(sample, set(exclude_fields)),
num_proc=num_proc,
@@ -127,7 +128,7 @@ def map_split(

# Filter out samples where new_column_name is None
return ds_mapped.filter(
lambda sample: (True if sample[self.new_column_name] != None else False),
lambda sample: sample[self.new_column_name] != None,
num_proc=num_proc,
writer_batch_size=writer_batch_size,
)
@@ -143,20 +144,26 @@ def _map_sample(self, sample, exclude_fields):
rendered = jinja2.Template(
self.template, undefined=jinja2.StrictUndefined
).render(**filtered_sample, json_dump=json.dumps, text_proc=text_proc)

except jinja2.TemplateError as e:
error_message = str(e)
if "The formatted text is empty." in error_message:
except Exception as e:
if isinstance(e, text_proc.GarbageUtteranceError):
print("Formatted text is empty. Setting output to None.")
sample[self.new_column_name] = None
self.template_failures += 1
else:
return sample
elif isinstance(e, text_proc.EmptyTranscriptError):
print("Empty transcript after processing. Setting output to None.")
sample[self.new_column_name] = None
self.template_failures += 1
return sample
elif isinstance(e, jinja2.TemplateError):
print(f"Error rendering template: {e}")
print(f"template: {self.template}")
print(f"sample keys: {list(filtered_sample.keys())}")
raise ValueError(
f"Template rendering failed. Make sure all keys in the template exist in the sample."
) from e
else:
raise e

if self.json_mode:
turns = yaml.safe_load(rendered)
@@ -189,7 +196,7 @@ def _map_sample(self, sample, exclude_fields):
class DatasetToolArgs:
# HF source dataset parameters
dataset_name: str = simple_parsing.field(alias="-d")
dataset_subset: Optional[str] = simple_parsing.field(default=None, alias="-S")
dataset_subset: Optional[str] = simple_parsing.field(default="default", alias="-S")
dataset_split: Optional[str] = simple_parsing.field(default=None, alias="-s")

# Local processing parameters
@@ -210,6 +217,10 @@ class DatasetToolArgs:
private: bool = simple_parsing.field(default=False)
token: Optional[str] = None

# Chunk processing parameters
max_chunk_split: int = simple_parsing.field(default=10)
chunk_split_threshold: int = simple_parsing.field(default=50000)

task: Union[TtsTask, TextGenerationTask] = simple_parsing.subgroups(
{"tts": TtsTask, "textgen": TextGenerationTask}, # type: ignore
default_factory=TtsTask,
@@ -223,6 +234,124 @@ def __post_init__(self):
self.upload_split = self.dataset_split


class DatasetChunkProcessor:
args: DatasetToolArgs
cache_dir: str = ".cache/ds_tool/processed_datasets"
chunks_not_uploaded: List[Tuple[int, int]] = []
total_samples_processed: Dict[str, int] = {}

def __init__(self, args: DatasetToolArgs):
self.args = args

def process_and_upload_split(self, split_name: str, ds_split: datasets.Dataset):
failed_chunk_ranges = self._dynamic_chunk(
split_name, ds_split, 0, len(ds_split)
)

while len(failed_chunk_ranges) > 0:
new_failed_ranges = []
for start, end in failed_chunk_ranges:
print(f"Retrying failed chunk range [{start}, {end})")
new_failed_ranges.extend(
self._dynamic_chunk(split_name, ds_split, start, end)
)
failed_chunk_ranges = new_failed_ranges
print(f"Could not upload chunks: {self.chunks_not_uploaded}")
print(
f"Finished processing and uploading all chunks for split {split_name}. Total samples processed: {self.total_samples_processed}"
)

def _dynamic_chunk(
self,
split_name: str,
ds_split: datasets.Dataset,
start_index: int,
end_index: int,
):
original_chunk_size = end_index - start_index
if original_chunk_size < self.args.chunk_split_threshold:
total_chunks = 1
chunk_size = original_chunk_size
print(
f"Chunk range [{start_index}, {end_index}) is too small to split further. Processing and uploading as a single chunk."
)
else:
total_chunks = self.args.max_chunk_split
chunk_size = math.ceil(original_chunk_size / total_chunks)
failed_chunk_ranges = []
print(
f"Processing and uploading {total_chunks} chunks for range [{start_index}, {end_index}) with chunk size {chunk_size}"
)
for i in range(start_index, end_index, chunk_size):
chunk_start = i
chunk_end = min(i + chunk_size, start_index + original_chunk_size)

ds_chunk = ds_split.select(range(chunk_start, chunk_end))
ds_chunk_name = f"chunk-range-{chunk_start:09d}-{chunk_end:09d}.parquet"
ds_chunk_hub_path = os.path.join(
str(self.args.upload_subset), split_name, ds_chunk_name
)
ds_chunk_cache_path = os.path.join(
self.cache_dir,
self.args.dataset_name.replace("/", "__"),
str(self.args.upload_subset),
split_name,
ds_chunk_name,
)
try:
if os.path.exists(ds_chunk_cache_path):
print(
f"Skipping chunk {ds_chunk_name} as it has already been processed and uploaded."
)
ds_chunk_processed = datasets.Dataset.from_parquet(
ds_chunk_cache_path
)
else:
print(f"Processing chunk {ds_chunk_name}")
ds_chunk_processed = self._process(ds_chunk)
self._upload(ds_chunk_processed, ds_chunk_hub_path, split_name)
ds_chunk_processed.to_parquet(ds_chunk_cache_path)
self.total_samples_processed[split_name] = (
self.total_samples_processed.get(split_name, 0)
+ len(ds_chunk_processed)
)

except Exception as e:
print(f"Failed to upload chunk {ds_chunk_name}: {e}. Retrying later.")
if total_chunks == 1:
print(
f"Finished processing and uploading 0/1 chunks for range [{start_index}, {end_index})"
)
self.chunks_not_uploaded.append((start_index, end_index))
return []
failed_chunk_ranges.append((chunk_start, chunk_end))
successful_chunks = self.args.max_chunk_split - len(failed_chunk_ranges)
print(
f"Finished processing and uploading {successful_chunks}/{self.args.max_chunk_split} chunks for range [{start_index}, {end_index})"
)
return failed_chunk_ranges

def _process(self, ds_chunk: datasets.Dataset) -> datasets.Dataset:
return self.args.task.map_split(
ds_chunk,
self.args.num_workers,
self.args.writer_batch_size,
self.args.exclude_fields,
)

def _upload(self, ds_chunk_processed: datasets.Dataset, data_dir: str, split_name):
print(f"Uploading chunk to hub: {data_dir}")
hub_args: Dict[str, Any] = {
"config_name": self.args.dataset_subset,
"token": self.args.token or os.environ.get("HF_TOKEN"),
"private": self.args.private,
"data_dir": data_dir,
"num_shards": self.args.num_shards,
"split": split_name,
}
ds_chunk_processed.push_to_hub(self.args.upload_name, **hub_args)


def main(args: DatasetToolArgs):
ds_name = args.dataset_name
print(f'Loading dataset "{ds_name}" for task {args.task}')
@@ -236,47 +365,18 @@ def main(args: DatasetToolArgs):
if len(data_dict) > 1 and args.upload_split:
raise ValueError("Cannot upload multiple splits to a single split")

for split, ds_split in data_dict.items():
ds_chunk_proc = DatasetChunkProcessor(args)

for split_name, ds_split in data_dict.items():
print(
f"Processing dataset: {ds_name}, subset {args.dataset_subset}, split {args.dataset_split}, containing {len(ds_split)} samples"
f"Processing dataset: {args.dataset_name}, subset {args.dataset_subset}, split {split_name}, containing {len(ds_split)} samples"
)
if args.shuffle:
ds_split = ds_split.shuffle(seed=args.shuffle_seed)
if args.num_samples:
ds_split = ds_split.select(range(args.num_samples))
data_dict[split] = args.task.map_split(
ds_split,
args.num_workers,
args.writer_batch_size,
args.exclude_fields,
)

hub_args: Dict[str, Any] = {
"config_name": args.upload_subset or "default",
"token": args.token or os.environ.get("HF_TOKEN"),
"revision": args.upload_branch,
"private": args.private,
}

if args.num_shards is not None:
hub_args["num_shards"] = args.num_shards

try:
if args.dataset_split:
data_dict[args.dataset_split].push_to_hub(
args.upload_name, split=args.upload_split, **hub_args
)
else:
data_dict.push_to_hub(args.upload_name, **hub_args)
except Exception as e:
print(f"Failed to push to hub: {e}")

# If the push fails or upload_name is not specified, save the data locally.
for split in data_dict.keys():
output_name = f"{split}-00000-of-00001.parquet"
data_dict[split].to_parquet(output_name)
print(f"Saved to {output_name}")
print(f"Sample {0} of {split}: {data_dict[split][0]}")
ds_chunk_proc.process_and_upload_split(split_name, ds_split)


if __name__ == "__main__":
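
To make the chunking flow easier to follow, here is a small standalone sketch (not part of the PR) of the range arithmetic that _dynamic_chunk applies; the defaults mirror max_chunk_split=10 and chunk_split_threshold=50000 from DatasetToolArgs, and failure handling and retries are omitted.

import math
from typing import List, Tuple

def chunk_ranges(
    start: int, end: int, max_chunk_split: int = 10, chunk_split_threshold: int = 50_000
) -> List[Tuple[int, int]]:
    """Split [start, end) the same way _dynamic_chunk does (illustration only)."""
    size = end - start
    if size < chunk_split_threshold:
        # Too small to split further: processed and uploaded as a single chunk.
        return [(start, end)]
    chunk_size = math.ceil(size / max_chunk_split)
    return [(i, min(i + chunk_size, end)) for i in range(start, end, chunk_size)]

# A 120,000-sample split is first cut into 10 chunks of 12,000 samples each;
# a chunk that later fails to upload is below the 50,000 threshold, so on retry
# it is reprocessed as a single chunk rather than split again.
print(chunk_ranges(0, 120_000))

Each produced range becomes a parquet shard named chunk-range-{start:09d}-{end:09d}.parquet, uploaded under {upload_subset}/{split}/ on the Hub and cached under .cache/ds_tool/processed_datasets/ so that reruns skip chunks that were already processed.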