Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add chunking to ds_tool #97

Merged
merged 31 commits into from
Sep 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
81bb206
Added chunking
liPatrick Aug 23, 2024
3ba9ff9
Dynamic chunking
liPatrick Aug 23, 2024
2ce70f3
Raise error in jinja template
liPatrick Aug 23, 2024
168275b
Format errors
liPatrick Aug 23, 2024
daa2ce9
Fix test
liPatrick Aug 23, 2024
31a8a13
Return sample when format is wrong
liPatrick Aug 23, 2024
b1f6d67
Remove template failures counter, doenst work with multi-proc
liPatrick Aug 26, 2024
4163628
Addressing comments
liPatrick Aug 26, 2024
255c46d
Handle text proc asr error in dataset.py
liPatrick Aug 26, 2024
f1236ce
removing extra prints
liPatrick Aug 26, 2024
eab6b80
Make process upload split recurisve
liPatrick Aug 26, 2024
4e42d58
Add more comments
liPatrick Aug 27, 2024
93c25bf
More comments
liPatrick Aug 27, 2024
1479d91
Use None instead of empty quotes. Type issue resolved
liPatrick Aug 27, 2024
e5ae10f
Chunked dataset subclass
liPatrick Aug 27, 2024
dce349e
Merge branch 'main' into patrick/chunking-ds_tool
liPatrick Aug 29, 2024
ded2f29
HF readme integration
liPatrick Sep 4, 2024
b7b2e80
format
liPatrick Sep 4, 2024
5f70e0f
Add dataset version to load
liPatrick Sep 4, 2024
653e3b0
Remove empty audio
liPatrick Sep 4, 2024
9c66031
Moved asr error try catch to get_transcribe_sample
liPatrick Sep 5, 2024
230f962
Remove total samples processed
liPatrick Sep 5, 2024
ac9bc03
Change continuation to text
liPatrick Sep 5, 2024
0adf5a6
Some fixes
liPatrick Sep 6, 2024
6ebadb0
Fix more bugs
liPatrick Sep 6, 2024
0d08712
Address comments
liPatrick Sep 6, 2024
d358768
Address comments
liPatrick Sep 6, 2024
3f8a0db
Fix import format
liPatrick Sep 6, 2024
f60b90e
Extra filter method
liPatrick Sep 6, 2024
cdc6dee
Check empty columns filter
liPatrick Sep 6, 2024
bddf33c
Add empty column check
liPatrick Sep 6, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 15 additions & 12 deletions ultravox/data/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -448,8 +448,11 @@ def _get_transcribe_sample(
row: transformers.BatchFeature,
tcol: str = "text",
tproc: Optional[Callable[[str], str]] = None,
) -> VoiceSample:
text = tproc(row[tcol]) if tproc else row[tcol]
) -> Optional[VoiceSample]:
try:
text = tproc(row[tcol]) if tproc else row[tcol]
except text_proc.FormatASRError:
return None
return self._make_sample(
self._get_transcribe_messages(text),
self._get_audio(row),
Expand Down Expand Up @@ -478,7 +481,7 @@ def __init__(self, args: VoiceDatasetArgs) -> None:
)
self._init_dataset(dataset, 73)

def _get_sample(self, row: transformers.BatchFeature) -> VoiceSample:
def _get_sample(self, row: transformers.BatchFeature) -> Optional[VoiceSample]:
return self._get_transcribe_sample(row, tproc=text_proc.format_asr_text)


Expand Down Expand Up @@ -534,7 +537,7 @@ class AnyInstructAnswerDataset(AnyInstructDataset):
def __init__(self, args: VoiceDatasetArgs) -> None:
super().__init__(args)

def _get_sample(self, row: transformers.BatchFeature) -> VoiceSample:
def _get_sample(self, row: transformers.BatchFeature) -> Optional[VoiceSample]:
chat = row["chat"]
return self._make_sample(
self._get_answer_messages(chat[0]["message"], chat[1]["message"]),
Expand All @@ -547,7 +550,7 @@ class AnyInstructInputDataset(AnyInstructDataset):
def __init__(self, args: VoiceDatasetArgs) -> None:
super().__init__(args)

def _get_sample(self, row: transformers.BatchFeature) -> VoiceSample:
def _get_sample(self, row: transformers.BatchFeature) -> Optional[VoiceSample]:
audio_transcript = row["chat"][0]["message"]
return self._make_sample(
self._get_transcribe_messages(audio_transcript),
Expand All @@ -560,7 +563,7 @@ class AnyInstructOutputDataset(AnyInstructDataset):
def __init__(self, args: VoiceDatasetArgs) -> None:
super().__init__(args)

def _get_sample(self, row: transformers.BatchFeature) -> VoiceSample:
def _get_sample(self, row: transformers.BatchFeature) -> Optional[VoiceSample]:
audio_transcript = row["chat"][1]["message"]
return self._make_sample(
self._get_transcribe_messages(audio_transcript),
Expand Down Expand Up @@ -589,7 +592,7 @@ def _get_sample(self, row: transformers.BatchFeature) -> Optional[VoiceSample]:


class BoolQInputDataset(BoolQDataset):
def _get_sample(self, row: transformers.BatchFeature) -> VoiceSample:
def _get_sample(self, row: transformers.BatchFeature) -> Optional[VoiceSample]:
return self._get_transcribe_sample(row, tcol="question")


Expand Down Expand Up @@ -812,7 +815,7 @@ def __init__(self, args: VoiceDatasetArgs) -> None:
ds = ds.shuffle(seed=self._args.shuffle_seed)
self._init_dataset(ds)

def _get_sample(self, row: transformers.BatchFeature) -> VoiceSample:
def _get_sample(self, row: transformers.BatchFeature) -> Optional[VoiceSample]:
return self._get_transcribe_sample(row, tproc=text_proc.format_asr_text)


Expand All @@ -832,7 +835,7 @@ def __init__(self, args: VoiceDatasetArgs) -> None:
)
self._init_dataset(dataset)

def _get_sample(self, row) -> VoiceSample:
def _get_sample(self, row) -> Optional[VoiceSample]:
return self._get_transcribe_sample(row, tproc=text_proc.format_asr_text)


Expand All @@ -852,7 +855,7 @@ def __init__(self, args: VoiceDatasetArgs) -> None:
)
self._init_dataset(dataset)

def _get_sample(self, row) -> VoiceSample:
def _get_sample(self, row) -> Optional[VoiceSample]:
return self._get_transcribe_sample(row, tcol="raw_text")


Expand All @@ -875,7 +878,7 @@ def __init__(self, args: VoiceDatasetArgs, lang: str = "en") -> None:
)
self._init_dataset(dataset)

def _get_sample(self, row) -> VoiceSample:
def _get_sample(self, row) -> Optional[VoiceSample]:
return self._get_transcribe_sample(row, tcol="sentence")


Expand Down Expand Up @@ -977,7 +980,7 @@ def __init__(self, args: VoiceDatasetArgs) -> None:
)
self._init_dataset(dataset)

def _get_sample(self, row) -> VoiceSample:
def _get_sample(self, row) -> Optional[VoiceSample]:
return self._get_transcribe_sample(row, tcol="text")


Expand Down
4 changes: 2 additions & 2 deletions ultravox/data/datasets_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def __init__(self, n: int, args: Optional[datasets.VoiceDatasetArgs] = None):

self._init_dataset(FakeHuggingFaceIterableDataset(n), n)

def _get_sample(self, row: BatchFeature) -> datasets.VoiceSample:
def _get_sample(self, row: BatchFeature) -> Optional[datasets.VoiceSample]:
return self._get_transcribe_sample(row)


Expand All @@ -72,7 +72,7 @@ def __init__(
super().__init__(args or datasets.VoiceDatasetArgs())
self._init_dataset(FakeHuggingFaceIterableDataset(n), config.total_samples)

def _get_sample(self, row: BatchFeature) -> datasets.VoiceSample:
def _get_sample(self, row: BatchFeature) -> Optional[datasets.VoiceSample]:
return self._get_transcribe_sample(row)


Expand Down
13 changes: 10 additions & 3 deletions ultravox/data/text_proc.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,11 @@
import nltk # needed for truecase
import truecase


class FormatASRError(ValueError):
pass


# only in master thread per node to avoid
# other threads overwriting the downloaded .zip
if int(os.environ.get("LOCAL_RANK", 0)) == 0:
Expand All @@ -31,15 +36,17 @@ def format_asr_text(text: str) -> str:
remaining_words = []
for word in text.split():
if word in GIGASPEECH_GARBAGE_UTTERANCE_TAGS:
return ""
raise FormatASRError(f"Garbage utterance tag found: {word}")
if word in GIGASPEECH_PUNCTUATIONS:
word = GIGASPEECH_PUNCTUATIONS[word]
remaining_words.append(word)

text = " ".join(remaining_words)
text = truecase.get_true_case(text)

return text.strip()
text_stripped = text.strip()
if len(text_stripped) == 0:
raise FormatASRError("Empty text after processing")
return text_stripped


CONVERSATIONAL_FILLER = [
Expand Down
9 changes: 6 additions & 3 deletions ultravox/data/text_proc_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,12 @@
"I SEE LOTS OF PEOPLE HAVE DRONES HERE <COMMA> MAVERICK AS WELL <PERIOD> ",
"I see lots of people have drones here, maverick as well.",
),
# truecase messes with the case of special tags too, but we probably don't care about that
("<NOISE> OH WHAT WAS THAT?", ""),
],
)
def test_no_space_punctuation(text, expected):
def test_format_asr_text(text, expected):
assert text_proc.format_asr_text(text) == expected


def test_garbage_utterance():
with pytest.raises(text_proc.FormatASRError):
text_proc.format_asr_text("<NOISE> OH WHAT WAS THAT?")
Loading
Loading