Skip to content

Commit

Permalink
Add chunking to ds_tool (#97)
Browse files Browse the repository at this point in the history
* Added chunking

* Dynamic chunking

* Raise error in jinja template

* Format errors

* Fix test

* Return sample when format is wrong

* Remove template failures counter, doesn't work with multi-proc

* Addressing comments

* Handle text proc asr error in dataset.py

* removing extra prints

* Make process upload split recursive

* Add more comments

* More comments

* Use None instead of empty quotes. Type issue resolved

* Chunked dataset subclass

* HF readme integration

* format

* Add dataset version to load

* Remove empty audio

* Moved asr error try catch to get_transcribe_sample

* Remove total samples processed

* Change continuation to text

* Some fixes

* Fix more bugs

* Address comments

* Address comments

* Fix import format

* Extra filter method

* Check empty columns filter

* Add empty column check
  • Loading branch information
liPatrick authored Sep 6, 2024
1 parent 74e3998 commit 96a17f5
Show file tree
Hide file tree
Showing 7 changed files with 616 additions and 80 deletions.
27 changes: 15 additions & 12 deletions ultravox/data/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -448,8 +448,11 @@ def _get_transcribe_sample(
row: transformers.BatchFeature,
tcol: str = "text",
tproc: Optional[Callable[[str], str]] = None,
) -> VoiceSample:
text = tproc(row[tcol]) if tproc else row[tcol]
) -> Optional[VoiceSample]:
try:
text = tproc(row[tcol]) if tproc else row[tcol]
except text_proc.FormatASRError:
return None
return self._make_sample(
self._get_transcribe_messages(text),
self._get_audio(row),
Expand Down Expand Up @@ -478,7 +481,7 @@ def __init__(self, args: VoiceDatasetArgs) -> None:
)
self._init_dataset(dataset, 73)

def _get_sample(self, row: transformers.BatchFeature) -> VoiceSample:
def _get_sample(self, row: transformers.BatchFeature) -> Optional[VoiceSample]:
return self._get_transcribe_sample(row, tproc=text_proc.format_asr_text)


Expand Down Expand Up @@ -534,7 +537,7 @@ class AnyInstructAnswerDataset(AnyInstructDataset):
def __init__(self, args: VoiceDatasetArgs) -> None:
super().__init__(args)

def _get_sample(self, row: transformers.BatchFeature) -> VoiceSample:
def _get_sample(self, row: transformers.BatchFeature) -> Optional[VoiceSample]:
chat = row["chat"]
return self._make_sample(
self._get_answer_messages(chat[0]["message"], chat[1]["message"]),
Expand All @@ -547,7 +550,7 @@ class AnyInstructInputDataset(AnyInstructDataset):
def __init__(self, args: VoiceDatasetArgs) -> None:
super().__init__(args)

def _get_sample(self, row: transformers.BatchFeature) -> VoiceSample:
def _get_sample(self, row: transformers.BatchFeature) -> Optional[VoiceSample]:
audio_transcript = row["chat"][0]["message"]
return self._make_sample(
self._get_transcribe_messages(audio_transcript),
Expand All @@ -560,7 +563,7 @@ class AnyInstructOutputDataset(AnyInstructDataset):
def __init__(self, args: VoiceDatasetArgs) -> None:
super().__init__(args)

def _get_sample(self, row: transformers.BatchFeature) -> VoiceSample:
def _get_sample(self, row: transformers.BatchFeature) -> Optional[VoiceSample]:
audio_transcript = row["chat"][1]["message"]
return self._make_sample(
self._get_transcribe_messages(audio_transcript),
Expand Down Expand Up @@ -589,7 +592,7 @@ def _get_sample(self, row: transformers.BatchFeature) -> Optional[VoiceSample]:


class BoolQInputDataset(BoolQDataset):
def _get_sample(self, row: transformers.BatchFeature) -> VoiceSample:
def _get_sample(self, row: transformers.BatchFeature) -> Optional[VoiceSample]:
return self._get_transcribe_sample(row, tcol="question")


Expand Down Expand Up @@ -812,7 +815,7 @@ def __init__(self, args: VoiceDatasetArgs) -> None:
ds = ds.shuffle(seed=self._args.shuffle_seed)
self._init_dataset(ds)

def _get_sample(self, row: transformers.BatchFeature) -> VoiceSample:
def _get_sample(self, row: transformers.BatchFeature) -> Optional[VoiceSample]:
return self._get_transcribe_sample(row, tproc=text_proc.format_asr_text)


Expand All @@ -832,7 +835,7 @@ def __init__(self, args: VoiceDatasetArgs) -> None:
)
self._init_dataset(dataset)

def _get_sample(self, row) -> VoiceSample:
def _get_sample(self, row) -> Optional[VoiceSample]:
return self._get_transcribe_sample(row, tproc=text_proc.format_asr_text)


Expand All @@ -852,7 +855,7 @@ def __init__(self, args: VoiceDatasetArgs) -> None:
)
self._init_dataset(dataset)

def _get_sample(self, row) -> VoiceSample:
def _get_sample(self, row) -> Optional[VoiceSample]:
return self._get_transcribe_sample(row, tcol="raw_text")


Expand All @@ -875,7 +878,7 @@ def __init__(self, args: VoiceDatasetArgs, lang: str = "en") -> None:
)
self._init_dataset(dataset)

def _get_sample(self, row) -> VoiceSample:
def _get_sample(self, row) -> Optional[VoiceSample]:
return self._get_transcribe_sample(row, tcol="sentence")


Expand Down Expand Up @@ -977,7 +980,7 @@ def __init__(self, args: VoiceDatasetArgs) -> None:
)
self._init_dataset(dataset)

def _get_sample(self, row) -> VoiceSample:
def _get_sample(self, row) -> Optional[VoiceSample]:
return self._get_transcribe_sample(row, tcol="text")


Expand Down
4 changes: 2 additions & 2 deletions ultravox/data/datasets_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def __init__(self, n: int, args: Optional[datasets.VoiceDatasetArgs] = None):

self._init_dataset(FakeHuggingFaceIterableDataset(n), n)

def _get_sample(self, row: BatchFeature) -> datasets.VoiceSample:
def _get_sample(self, row: BatchFeature) -> Optional[datasets.VoiceSample]:
return self._get_transcribe_sample(row)


Expand All @@ -72,7 +72,7 @@ def __init__(
super().__init__(args or datasets.VoiceDatasetArgs())
self._init_dataset(FakeHuggingFaceIterableDataset(n), config.total_samples)

def _get_sample(self, row: BatchFeature) -> datasets.VoiceSample:
def _get_sample(self, row: BatchFeature) -> Optional[datasets.VoiceSample]:
return self._get_transcribe_sample(row)


Expand Down
13 changes: 10 additions & 3 deletions ultravox/data/text_proc.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,11 @@
import nltk # needed for truecase
import truecase


# Raised by format_asr_text when a transcript is unusable: either a garbage
# utterance tag was found or the text is empty after processing.
# Subclasses ValueError.
class FormatASRError(ValueError):
pass


# only in master thread per node to avoid
# other threads overwriting the downloaded .zip
if int(os.environ.get("LOCAL_RANK", 0)) == 0:
Expand All @@ -31,15 +36,17 @@ def format_asr_text(text: str) -> str:
remaining_words = []
for word in text.split():
if word in GIGASPEECH_GARBAGE_UTTERANCE_TAGS:
return ""
raise FormatASRError(f"Garbage utterance tag found: {word}")
if word in GIGASPEECH_PUNCTUATIONS:
word = GIGASPEECH_PUNCTUATIONS[word]
remaining_words.append(word)

text = " ".join(remaining_words)
text = truecase.get_true_case(text)

return text.strip()
text_stripped = text.strip()
if len(text_stripped) == 0:
raise FormatASRError("Empty text after processing")
return text_stripped


CONVERSATIONAL_FILLER = [
Expand Down
9 changes: 6 additions & 3 deletions ultravox/data/text_proc_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,12 @@
"I SEE LOTS OF PEOPLE HAVE DRONES HERE <COMMA> MAVERICK AS WELL <PERIOD> ",
"I see lots of people have drones here, maverick as well.",
),
# truecase messes with the case of special tags too, but we probably don't care about that
("<NOISE> OH WHAT WAS THAT?", ""),
],
)
def test_no_space_punctuation(text, expected):
def test_format_asr_text(text, expected):
assert text_proc.format_asr_text(text) == expected


# Checks that format_asr_text raises FormatASRError for an utterance that
# starts with the <NOISE> tag (presumably one of the garbage utterance tags --
# confirm against GIGASPEECH_GARBAGE_UTTERANCE_TAGS).
def test_garbage_utterance():
with pytest.raises(text_proc.FormatASRError):
text_proc.format_asr_text("<NOISE> OH WHAT WAS THAT?")
Loading

0 comments on commit 96a17f5

Please sign in to comment.