diff --git a/ultravox/data/datasets.py b/ultravox/data/datasets.py
index aa3fb1b9..1ea40c95 100644
--- a/ultravox/data/datasets.py
+++ b/ultravox/data/datasets.py
@@ -448,8 +448,11 @@ def _get_transcribe_sample(
         row: transformers.BatchFeature,
         tcol: str = "text",
         tproc: Optional[Callable[[str], str]] = None,
-    ) -> VoiceSample:
-        text = tproc(row[tcol]) if tproc else row[tcol]
+    ) -> Optional[VoiceSample]:
+        try:
+            text = tproc(row[tcol]) if tproc else row[tcol]
+        except text_proc.FormatASRError:
+            return None
         return self._make_sample(
             self._get_transcribe_messages(text),
             self._get_audio(row),
@@ -478,7 +481,7 @@ def __init__(self, args: VoiceDatasetArgs) -> None:
         )
         self._init_dataset(dataset, 73)
 
-    def _get_sample(self, row: transformers.BatchFeature) -> VoiceSample:
+    def _get_sample(self, row: transformers.BatchFeature) -> Optional[VoiceSample]:
         return self._get_transcribe_sample(row, tproc=text_proc.format_asr_text)
 
 
@@ -534,7 +537,7 @@ class AnyInstructAnswerDataset(AnyInstructDataset):
     def __init__(self, args: VoiceDatasetArgs) -> None:
         super().__init__(args)
 
-    def _get_sample(self, row: transformers.BatchFeature) -> VoiceSample:
+    def _get_sample(self, row: transformers.BatchFeature) -> Optional[VoiceSample]:
         chat = row["chat"]
         return self._make_sample(
             self._get_answer_messages(chat[0]["message"], chat[1]["message"]),
@@ -547,7 +550,7 @@ class AnyInstructInputDataset(AnyInstructDataset):
     def __init__(self, args: VoiceDatasetArgs) -> None:
         super().__init__(args)
 
-    def _get_sample(self, row: transformers.BatchFeature) -> VoiceSample:
+    def _get_sample(self, row: transformers.BatchFeature) -> Optional[VoiceSample]:
         audio_transcript = row["chat"][0]["message"]
         return self._make_sample(
             self._get_transcribe_messages(audio_transcript),
@@ -560,7 +563,7 @@ class AnyInstructOutputDataset(AnyInstructDataset):
     def __init__(self, args: VoiceDatasetArgs) -> None:
         super().__init__(args)
 
-    def _get_sample(self, row: transformers.BatchFeature) -> VoiceSample:
+    def _get_sample(self, row: transformers.BatchFeature) -> Optional[VoiceSample]:
         audio_transcript = row["chat"][1]["message"]
         return self._make_sample(
             self._get_transcribe_messages(audio_transcript),
@@ -589,7 +592,7 @@ def _get_sample(self, row: transformers.BatchFeature) -> Optional[VoiceSample]:
 
 
 class BoolQInputDataset(BoolQDataset):
-    def _get_sample(self, row: transformers.BatchFeature) -> VoiceSample:
+    def _get_sample(self, row: transformers.BatchFeature) -> Optional[VoiceSample]:
         return self._get_transcribe_sample(row, tcol="question")
 
 
@@ -812,7 +815,7 @@ def __init__(self, args: VoiceDatasetArgs) -> None:
             ds = ds.shuffle(seed=self._args.shuffle_seed)
         self._init_dataset(ds)
 
-    def _get_sample(self, row: transformers.BatchFeature) -> VoiceSample:
+    def _get_sample(self, row: transformers.BatchFeature) -> Optional[VoiceSample]:
         return self._get_transcribe_sample(row, tproc=text_proc.format_asr_text)
 
 
@@ -832,7 +835,7 @@ def __init__(self, args: VoiceDatasetArgs) -> None:
         )
         self._init_dataset(dataset)
 
-    def _get_sample(self, row) -> VoiceSample:
+    def _get_sample(self, row) -> Optional[VoiceSample]:
         return self._get_transcribe_sample(row, tproc=text_proc.format_asr_text)
 
 
@@ -852,7 +855,7 @@ def __init__(self, args: VoiceDatasetArgs) -> None:
         )
         self._init_dataset(dataset)
 
-    def _get_sample(self, row) -> VoiceSample:
+    def _get_sample(self, row) -> Optional[VoiceSample]:
         return self._get_transcribe_sample(row, tcol="raw_text")
 
 
@@ -875,7 +878,7 @@ def __init__(self, args: VoiceDatasetArgs, lang: str = "en") -> None:
         )
         self._init_dataset(dataset)
 
-    def _get_sample(self, row) -> VoiceSample:
+    def _get_sample(self, row) -> Optional[VoiceSample]:
         return self._get_transcribe_sample(row, tcol="sentence")
 
 
@@ -977,7 +980,7 @@ def __init__(self, args: VoiceDatasetArgs) -> None:
         )
         self._init_dataset(dataset)
 
-    def _get_sample(self, row) -> VoiceSample:
+    def _get_sample(self, row) -> Optional[VoiceSample]:
         return self._get_transcribe_sample(row, tcol="text")
 
 
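Note: every `_get_sample` override now returns `Optional[VoiceSample]`, so whatever iterates a dataset has to tolerate `None` rows. The iterator itself is outside this diff; the following is only a minimal sketch of that contract, with illustrative names (`iter_samples`, `get_sample`) that do not exist in the codebase.

# Hypothetical sketch: consuming _get_sample() results that may be None.
# Rows rejected by text_proc (FormatASRError) are skipped rather than yielded.
from typing import Callable, Iterable, Iterator, Optional

def iter_samples(
    rows: Iterable[dict],
    get_sample: Callable[[dict], Optional[object]],
) -> Iterator[object]:
    for row in rows:
        sample = get_sample(row)
        if sample is None:
            # _get_transcribe_sample() returned None because format_asr_text raised
            continue
        yield sample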
diff --git a/ultravox/data/datasets_test.py b/ultravox/data/datasets_test.py
index 2456ebf1..faacd3fa 100644
--- a/ultravox/data/datasets_test.py
+++ b/ultravox/data/datasets_test.py
@@ -56,7 +56,7 @@ def __init__(self, n: int, args: Optional[datasets.VoiceDatasetArgs] = None):
         self._init_dataset(FakeHuggingFaceIterableDataset(n), n)
 
-    def _get_sample(self, row: BatchFeature) -> datasets.VoiceSample:
+    def _get_sample(self, row: BatchFeature) -> Optional[datasets.VoiceSample]:
         return self._get_transcribe_sample(row)
 
 
@@ -72,7 +72,7 @@ def __init__(
         super().__init__(args or datasets.VoiceDatasetArgs())
         self._init_dataset(FakeHuggingFaceIterableDataset(n), config.total_samples)
 
-    def _get_sample(self, row: BatchFeature) -> datasets.VoiceSample:
+    def _get_sample(self, row: BatchFeature) -> Optional[datasets.VoiceSample]:
         return self._get_transcribe_sample(row)
 
 
diff --git a/ultravox/data/text_proc.py b/ultravox/data/text_proc.py
index 315ee15b..1e57db7c 100644
--- a/ultravox/data/text_proc.py
+++ b/ultravox/data/text_proc.py
@@ -7,6 +7,11 @@
 import nltk  # needed for truecase
 import truecase
 
+
+class FormatASRError(ValueError):
+    pass
+
+
 # only in master thread per node to avoid
 # other threads overwriting the downloaded .zip
 if int(os.environ.get("LOCAL_RANK", 0)) == 0:
@@ -31,15 +36,17 @@ def format_asr_text(text: str) -> str:
     remaining_words = []
     for word in text.split():
         if word in GIGASPEECH_GARBAGE_UTTERANCE_TAGS:
-            return ""
+            raise FormatASRError(f"Garbage utterance tag found: {word}")
         if word in GIGASPEECH_PUNCTUATIONS:
             word = GIGASPEECH_PUNCTUATIONS[word]
         remaining_words.append(word)
 
     text = " ".join(remaining_words)
     text = truecase.get_true_case(text)
-
-    return text.strip()
+    text_stripped = text.strip()
+    if len(text_stripped) == 0:
+        raise FormatASRError("Empty text after processing")
+    return text_stripped
 
 
 CONVERSATIONAL_FILLER = [
diff --git a/ultravox/data/text_proc_test.py b/ultravox/data/text_proc_test.py
index f0e22c1e..294c9bc2 100644
--- a/ultravox/data/text_proc_test.py
+++ b/ultravox/data/text_proc_test.py
@@ -10,9 +10,12 @@
             "I SEE LOTS OF PEOPLE HAVE DRONES HERE <COMMA> MAVERICK AS WELL <PERIOD> ",
             "I see lots of people have drones here, maverick as well.",
         ),
-        # truecase messes with the case of special tags too, but we probably don't care about that
-        ("<NOISE> OH WHAT WAS THAT?", ""),
     ],
 )
-def test_no_space_punctuation(text, expected):
+def test_format_asr_text(text, expected):
     assert text_proc.format_asr_text(text) == expected
+
+
+def test_garbage_utterance():
+    with pytest.raises(text_proc.FormatASRError):
+        text_proc.format_asr_text("<NOISE> OH WHAT WAS THAT?")
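Since `format_asr_text` now signals bad transcripts by raising instead of returning an empty string, call sites follow a catch-and-drop pattern. A small sketch (the input strings are invented and the GigaSpeech tag spellings are assumed):

# Sketch: exception-based contract of text_proc.format_asr_text.
from ultravox.data import text_proc

def try_format(raw: str):
    try:
        return text_proc.format_asr_text(raw)
    except text_proc.FormatASRError as e:
        # Raised for garbage-utterance tags and for text that is empty after cleaning.
        print(f"Dropping sample: {e}")
        return None

try_format("I SEE LOTS OF PEOPLE HAVE DRONES HERE <COMMA> MAVERICK AS WELL <PERIOD>")
try_format("<NOISE>")  # assumed garbage tag -> returns None here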
diff --git a/ultravox/tools/ds_tool/chunked_dataset.py b/ultravox/tools/ds_tool/chunked_dataset.py
new file mode 100644
index 00000000..a8308ceb
--- /dev/null
+++ b/ultravox/tools/ds_tool/chunked_dataset.py
@@ -0,0 +1,412 @@
+# mypy: ignore-errors
+import fnmatch
+import json
+import math
+import re
+import warnings
+from io import BytesIO
+from pathlib import Path
+from typing import Optional, Union
+
+import datasets
+import huggingface_hub
+from datasets.data_files import sanitize_patterns
+from datasets.info import DatasetInfo
+from datasets.info import DatasetInfosDict
+from datasets.naming import _split_re
+from datasets.splits import SplitDict
+from datasets.splits import SplitInfo
+from datasets.utils import logging
+from datasets.utils.metadata import MetadataConfigs
+from datasets.utils.py_utils import asdict
+from datasets.utils.py_utils import glob_pattern_to_regex
+from datasets.utils.py_utils import string_to_dict
+from huggingface_hub.hf_api import RepoFile
+
+logger = logging.get_logger(__name__)
+PUSH_TO_HUB_WITHOUT_METADATA_CONFIGS_SPLIT_PATTERN_SHARDED = (
+    "data/{split}-[0-9][0-9][0-9][0-9][0-9]-of-[0-9][0-9][0-9][0-9][0-9]*.parquet"
+)
+
+
+class ChunkedDataset(datasets.Dataset):
+    @classmethod
+    def from_dataset(cls, dataset):
+        """
+        Create a ChunkedDataset from an existing Dataset.
+        """
+        obj = cls(dataset.data.table)
+        obj.__dict__.update(dataset.__dict__)
+        return obj
+
+    def push_to_hub(
+        self,
+        repo_id: str,
+        config_name: str = "default",
+        set_default: Optional[bool] = None,
+        split: Optional[str] = None,
+        data_dir: Optional[str] = None,
+        commit_message: Optional[str] = None,
+        commit_description: Optional[str] = None,
+        private: Optional[bool] = False,
+        token: Optional[str] = None,
+        revision: Optional[str] = None,
+        branch="deprecated",
+        create_pr: Optional[bool] = False,
+        max_shard_size: Optional[Union[int, str]] = None,
+        num_shards: Optional[int] = None,
+        embed_external_files: bool = True,
+    ) -> huggingface_hub.CommitInfo:
+        """
+        This overrides the push_to_hub method to work with chunked uploads. The old method assumed
+        each write was supposed to override the existing split data in the README, but this method will append to the
+        existing split values in the README (i.e., download_size, num_examples, etc.).
+        """
+        if config_name == "data":
+            raise ValueError(
+                "`config_name` cannot be 'data'. Please, choose another name for configuration."
+            )
+
+        if max_shard_size is not None and num_shards is not None:
+            raise ValueError(
+                "Failed to push_to_hub: please specify either max_shard_size or num_shards, but not both."
+            )
+
+        if split is None:
+            split = str(self.split) if self.split is not None else "train"
+
+        if not re.match(_split_re, split):
+            raise ValueError(
+                f"Split name should match '{_split_re}' but got '{split}'."
+            )
+        if branch != "deprecated":
+            warnings.warn(
+                "'branch' was deprecated in favor of 'revision' in version 2.15.0 and will be removed in 3.0.0.\n"
+                f"You can remove this warning by passing 'revision={branch}' instead.",
+                FutureWarning,
+            )
+            revision = branch
+
+        api = huggingface_hub.HfApi(endpoint=datasets.config.HF_ENDPOINT, token=token)
+
+        repo_url = api.create_repo(
+            repo_id,
+            token=token,
+            repo_type="dataset",
+            private=private,
+            exist_ok=True,
+        )
+        repo_id = repo_url.repo_id
+
+        if revision is not None:
+            api.create_branch(
+                repo_id,
+                branch=revision,
+                token=token,
+                repo_type="dataset",
+                exist_ok=True,
+            )
+
+        if not data_dir:
+            data_dir = (
+                config_name if config_name != "default" else "data"
+            )  # for backward compatibility
+
+        additions, uploaded_size, dataset_nbytes = self._push_parquet_shards_to_hub(
+            repo_id=repo_id,
+            data_dir=data_dir,
+            split=split,
+            token=token,
+            revision=revision,
+            max_shard_size=max_shard_size,
+            num_shards=num_shards,
+            create_pr=create_pr,
+            embed_external_files=embed_external_files,
+        )
+
+        # Check if the repo already has a README.md and/or a dataset_infos.json to update them with the new split info (size and pattern)
+        # and delete old split shards (if they exist)
+        repo_with_dataset_card, repo_with_dataset_infos = False, False
+        deletions, deleted_size = [], 0
+        repo_splits = []  # use a list to keep the order of the splits
+        repo_files_to_add = [addition.path_in_repo for addition in additions]
+        for repo_file in api.list_repo_tree(
+            repo_id=repo_id,
+            revision=revision,
+            repo_type="dataset",
+            token=token,
+            recursive=True,
+        ):
+            if not isinstance(repo_file, RepoFile):
+                continue
+            if repo_file.rfilename == datasets.config.REPOCARD_FILENAME:
+                repo_with_dataset_card = True
+            elif repo_file.rfilename == datasets.config.DATASETDICT_INFOS_FILENAME:
+                repo_with_dataset_infos = True
+            elif (
+                repo_file.rfilename.startswith(f"{data_dir}/{split}-")
+                and repo_file.rfilename not in repo_files_to_add
+            ):
+                deletions.append(
+                    huggingface_hub.CommitOperationDelete(
+                        path_in_repo=repo_file.rfilename
+                    )
+                )
+                deleted_size += repo_file.size
+            elif fnmatch.fnmatch(
+                repo_file.rfilename,
+                PUSH_TO_HUB_WITHOUT_METADATA_CONFIGS_SPLIT_PATTERN_SHARDED.replace(
+                    "{split}", "*"
+                ),
+            ):
+                repo_split = string_to_dict(
+                    repo_file.rfilename,
+                    glob_pattern_to_regex(
+                        PUSH_TO_HUB_WITHOUT_METADATA_CONFIGS_SPLIT_PATTERN_SHARDED
+                    ),
+                )["split"]
+                if repo_split not in repo_splits:
+                    repo_splits.append(repo_split)
+
+        organization, dataset_name = (
+            repo_id.split("/") if "/" in repo_id else (None, repo_id)
+        )
+        info_to_dump = self.info.copy()
+        info_to_dump.download_checksums = None
+        info_to_dump.download_size = uploaded_size
+        info_to_dump.dataset_size = dataset_nbytes
+        info_to_dump.size_in_bytes = uploaded_size + dataset_nbytes
+        info_to_dump.config_name = config_name
+        info_to_dump.splits = SplitDict(
+            {
+                split: SplitInfo(
+                    split,
+                    num_bytes=dataset_nbytes,
+                    num_examples=len(self),
+                    dataset_name=dataset_name,
+                )
+            }
+        )
+        # get the info from the README to update them
+        if repo_with_dataset_card:
+            dataset_card_path = api.hf_hub_download(
+                repo_id,
+                datasets.config.REPOCARD_FILENAME,
+                repo_type="dataset",
+                revision=revision,
+            )
+            dataset_card = huggingface_hub.DatasetCard.load(Path(dataset_card_path))
+            dataset_card_data = dataset_card.data
+            metadata_configs = MetadataConfigs.from_dataset_card_data(dataset_card_data)
+            dataset_infos: DatasetInfosDict = DatasetInfosDict.from_dataset_card_data(
+                dataset_card_data
+            )
+            if dataset_infos and config_name in dataset_infos:
+                repo_info = dataset_infos[config_name]
+            else:
+                repo_info = None
+        # get the deprecated dataset_infos.json to update them
+        elif repo_with_dataset_infos:
+            dataset_card = None
+            dataset_card_data = huggingface_hub.DatasetCardData()
+            metadata_configs = MetadataConfigs()
+            dataset_infos_path = api.hf_hub_download(
+                repo_id,
+                datasets.config.DATASETDICT_INFOS_FILENAME,
+                repo_type="dataset",
+                revision=revision,
+            )
+            with open(dataset_infos_path, encoding="utf-8") as f:
+                dataset_infos: dict = json.load(f)
+                dataset_info = (
+                    dataset_infos.get(config_name, None) if dataset_infos else None
+                )
+                repo_info = (
+                    DatasetInfo.from_dict(dataset_info) if dataset_info else None
+                )
+        else:
+            dataset_card = None
+            dataset_card_data = huggingface_hub.DatasetCardData()
+            metadata_configs = MetadataConfigs()
+            repo_info = None
+        # Update the total info to dump from existing info.
+        if repo_info is not None:
+            logger.info("Updating downloaded metadata with the new split.")
+            # MODIFIED:
+            # New Addition:
+            # Keep the old split info to update the new split info
+            old_split = repo_info.splits.get(split, SplitInfo())
+            # MODIFIED:
+            # Old:
+            # if repo_info.splits and list(repo_info.splits) != [split]:
+            if repo_info.splits:
+                if self._info.features != repo_info.features:
+                    raise ValueError(
+                        f"Features of the new split don't match the features of the existing splits on the hub: {self._info.features} != {repo_info.features}"
+                    )
+
+            repo_info.download_checksums = None
+            repo_info.download_size = (repo_info.download_size or 0) + uploaded_size
+            repo_info.dataset_size = (repo_info.dataset_size or 0) + dataset_nbytes
+            repo_info.size_in_bytes = (
+                repo_info.download_size + repo_info.dataset_size
+            )
+            repo_info.splits.pop(split, None)
+            # MODIFIED:
+            # Old:
+            # repo_info.splits[split] = SplitInfo(
+            #     split, num_bytes=dataset_nbytes, num_examples=len(self), dataset_name=dataset_name
+            # )
+            repo_info.splits[split] = SplitInfo(
+                split,
+                num_bytes=old_split.num_bytes + dataset_nbytes,
+                num_examples=old_split.num_examples + len(self),
+                dataset_name=dataset_name,
+            )
+            info_to_dump = repo_info
+        # create the metadata configs if it was uploaded with push_to_hub before metadata configs existed
+        if not metadata_configs and repo_splits:
+            default_metadata_configs_to_dump = {
+                "data_files": [
+                    {"split": split, "path": f"data/{split}-*"} for split in repo_splits
+                ]
+            }
+            MetadataConfigs(
+                {"default": default_metadata_configs_to_dump}
+            ).to_dataset_card_data(dataset_card_data)
+        # update the metadata configs
+        if config_name in metadata_configs:
+            metadata_config = metadata_configs[config_name]
+            if "data_files" in metadata_config:
+                data_files_to_dump = sanitize_patterns(metadata_config["data_files"])
+            else:
+                data_files_to_dump = {}
+            # add the new split
+            # MODIFIED:
+            # Old:
+            # data_files_to_dump[split] = [f"{data_dir}/{split}-*"]
+            data_files_to_dump[split] = [f"{config_name}/{split}/**"]
+            metadata_config_to_dump = {
+                "data_files": [
+                    {
+                        "split": _split,
+                        "path": _pattern[0] if len(_pattern) == 1 else _pattern,
+                    }
+                    for _split, _pattern in data_files_to_dump.items()
+                ]
+            }
+        else:
+            # MODIFIED:
+            # Old:
+            # metadata_config_to_dump = {"data_files": [{"split": split, "path": f"{data_dir}/{split}-*"}]}
+            metadata_config_to_dump = {
+                "data_files": [{"split": split, "path": f"{config_name}/{split}/**"}]
+            }
+
+        if set_default and config_name != "default":
+            if metadata_configs:
+                default_config_name = metadata_configs.get_default_config_name()
+                if default_config_name == "default":
+                    raise ValueError(
+                        "There exists a configuration named 'default'. To set a different configuration as default, "
+                        "rename the 'default' one first."
+                    )
+                else:
+                    _ = metadata_configs[default_config_name].pop("default")
+            metadata_config_to_dump["default"] = True
+        # push to the deprecated dataset_infos.json
+        if repo_with_dataset_infos:
+            dataset_infos_path = api.hf_hub_download(
+                repo_id,
+                datasets.config.DATASETDICT_INFOS_FILENAME,
+                repo_type="dataset",
+                revision=revision,
+            )
+            with open(dataset_infos_path, encoding="utf-8") as f:
+                dataset_infos: dict = json.load(f)
+            dataset_infos[config_name] = asdict(info_to_dump)
+            buffer = BytesIO()
+            buffer.write(json.dumps(dataset_infos, indent=4).encode("utf-8"))
+            additions.append(
+                huggingface_hub.CommitOperationAdd(
+                    path_in_repo=datasets.config.DATASETDICT_INFOS_FILENAME,
+                    path_or_fileobj=buffer,
+                )
+            )
+        # push to README
+        DatasetInfosDict({config_name: info_to_dump}).to_dataset_card_data(
+            dataset_card_data
+        )
+        MetadataConfigs({config_name: metadata_config_to_dump}).to_dataset_card_data(
+            dataset_card_data
+        )
+        dataset_card = (
+            huggingface_hub.DatasetCard(f"---\n{dataset_card_data}\n---\n")
+            if dataset_card is None
+            else dataset_card
+        )
+        additions.append(
+            huggingface_hub.CommitOperationAdd(
+                path_in_repo=datasets.config.REPOCARD_FILENAME,
+                path_or_fileobj=str(dataset_card).encode(),
+            )
+        )
+
+        commit_message = (
+            commit_message if commit_message is not None else "Upload dataset"
+        )
+        if len(additions) <= datasets.config.UPLOADS_MAX_NUMBER_PER_COMMIT:
+            commit_info = api.create_commit(
+                repo_id,
+                operations=additions + deletions,
+                commit_message=commit_message,
+                commit_description=commit_description,
+                token=token,
+                repo_type="dataset",
+                revision=revision,
+                create_pr=create_pr,
+            )
+        else:
+            logger.info(
+                f"Number of files to upload is larger than {datasets.config.UPLOADS_MAX_NUMBER_PER_COMMIT}. Splitting the push into multiple commits."
+            )
+            num_commits = math.ceil(
+                len(additions) / datasets.config.UPLOADS_MAX_NUMBER_PER_COMMIT
+            )
+            for i in range(0, num_commits):
+                operations = additions[
+                    i
+                    * datasets.config.UPLOADS_MAX_NUMBER_PER_COMMIT : (i + 1)
+                    * datasets.config.UPLOADS_MAX_NUMBER_PER_COMMIT
+                ] + (deletions if i == 0 else [])
+                commit_info = api.create_commit(
+                    repo_id,
+                    operations=operations,
+                    commit_message=commit_message
+                    + f" (part {i:05d}-of-{num_commits:05d})",
+                    commit_description=commit_description,
+                    token=token,
+                    repo_type="dataset",
+                    revision=revision,
+                    create_pr=create_pr,
+                )
+                logger.info(
+                    f"Commit #{i+1} completed"
+                    + (
+                        f" (still {num_commits - i - 1} to go)"
+                        if num_commits - i - 1
+                        else ""
+                    )
+                    + "."
+                )
+        return commit_info
+
+
+# Function to convert Dataset to ChunkedDataset
+def convert_to_chunked_dataset(data) -> ChunkedDataset:
+    return (
+        ChunkedDataset.from_dataset(data)
+        if isinstance(data, datasets.Dataset)
+        else data
+    )
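A rough usage sketch of the override (the repository id, data_dir value, and toy rows below are invented; the call mirrors how ds_tool's _upload uses it):

# Sketch: pushing one processed chunk through ChunkedDataset so the README's
# split stats accumulate across chunk uploads instead of being overwritten.
import datasets
from ultravox.tools.ds_tool import chunked_dataset

ds_chunk = datasets.Dataset.from_dict({"text": ["hello", "world"]})  # toy chunk
chunked = chunked_dataset.convert_to_chunked_dataset(ds_chunk)
chunked.push_to_hub(
    "my-org/my-dataset",  # hypothetical repo
    config_name="default",
    split="train",
    data_dir="default/train/chunk-range-000000000-000000002",
)
# On each call the existing SplitInfo for `split` is read back and the new
# num_bytes/num_examples are added to it (old_split.num_examples + len(self)).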
diff --git a/ultravox/tools/ds_tool/continuation.jinja b/ultravox/tools/ds_tool/continuation.jinja
index 9ed1447a..824d6b14 100644
--- a/ultravox/tools/ds_tool/continuation.jinja
+++ b/ultravox/tools/ds_tool/continuation.jinja
@@ -1,8 +1,3 @@
-{% set formatted_text = text_proc.format_asr_text(text) %}
-{% if formatted_text != "" %}
 Continue the following text using less than 50 words:
 
-{{ formatted_text }}
-{% else %}
-{% raise ValueError("The formatted text is empty.") %}
-{% endif %}
\ No newline at end of file
+{{ text_proc.format_asr_text(text) }}
\ No newline at end of file
diff --git a/ultravox/tools/ds_tool/ds_tool.py b/ultravox/tools/ds_tool/ds_tool.py
index 35ae0743..750f62e4 100644
--- a/ultravox/tools/ds_tool/ds_tool.py
+++ b/ultravox/tools/ds_tool/ds_tool.py
@@ -1,14 +1,19 @@
 import dataclasses
 import json
+import math
 import os
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Dict, List, Optional, Tuple, Union
 
 import datasets
 import jinja2
 import openai
 import simple_parsing
 import yaml
+from tenacity import retry
+from tenacity import stop_after_attempt
+from tenacity import wait_fixed
 
+import ultravox.tools.ds_tool.chunked_dataset as chunked_dataset
 from ultravox.data import text_proc
 from ultravox.tools.ds_tool import caching
 from ultravox.tools.ds_tool import tts
@@ -95,8 +100,6 @@ class TextGenerationTask:
     max_tokens: int = 128
     temperature: float = 0
 
-    template_failures: int = 0
-
     def __post_init__(self):
         # The OAI client is separate from the task to avoid pickling issues when multiprocessing.
         global chat_client
@@ -122,12 +125,10 @@ def map_split(
             num_proc=num_proc,
             writer_batch_size=writer_batch_size,
         )
-        if self.template_failures == 0:
-            return ds_mapped
 
         # Filter out samples where new_column_name is None
        return ds_mapped.filter(
-            lambda sample: (True if sample[self.new_column_name] != None else False),
+            lambda sample: sample[self.new_column_name] is not None,
             num_proc=num_proc,
             writer_batch_size=writer_batch_size,
         )
@@ -143,20 +144,17 @@ def _map_sample(self, sample, exclude_fields):
             rendered = jinja2.Template(
                 self.template, undefined=jinja2.StrictUndefined
             ).render(**filtered_sample, json_dump=json.dumps, text_proc=text_proc)
-
+        except text_proc.FormatASRError as e:
+            print(f"Format ASR Error {e}. Filtering out sample.")
+            sample[self.new_column_name] = None
+            return sample
         except jinja2.TemplateError as e:
-            error_message = str(e)
-            if "The formatted text is empty." in error_message:
-                print("Formatted text is empty. Setting output to None.")
-                sample[self.new_column_name] = None
-                self.template_failures += 1
-            else:
-                print(f"Error rendering template: {e}")
-                print(f"template: {self.template}")
-                print(f"sample keys: {list(filtered_sample.keys())}")
-                raise ValueError(
-                    f"Template rendering failed. Make sure all keys in the template exist in the sample."
-                ) from e
+            print(f"Error rendering template: {e}")
+            print(f"template: {self.template}")
+            print(f"sample keys: {list(filtered_sample.keys())}")
+            raise ValueError(
+                f"Template rendering failed. Make sure all keys in the template exist in the sample."
+            ) from e
 
         if self.json_mode:
             turns = yaml.safe_load(rendered)
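With the template reduced to a single `text_proc.format_asr_text(text)` call, a bad transcript now surfaces as a FormatASRError thrown out of Jinja's render, which `_map_sample` converts into a `None` column value. A toy rendering sketch (the inline template string and the `<NOISE>` row are illustrative only):

# Sketch: FormatASRError raised inside the template escapes render() unchanged,
# so the caller no longer has to pattern-match Jinja error messages.
import jinja2

from ultravox.data import text_proc

template = (
    "Continue the following text using less than 50 words:\n\n"
    "{{ text_proc.format_asr_text(text) }}"
)
row = {"text": "<NOISE>"}  # made-up row containing only a garbage tag
try:
    rendered = jinja2.Template(template, undefined=jinja2.StrictUndefined).render(
        **row, text_proc=text_proc
    )
except text_proc.FormatASRError:
    rendered = None  # stored as None and later dropped by map_split's filter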
@@ -191,6 +189,7 @@ class DatasetToolArgs:
     dataset_name: str = simple_parsing.field(alias="-d")
     dataset_subset: Optional[str] = simple_parsing.field(default=None, alias="-S")
     dataset_split: Optional[str] = simple_parsing.field(default=None, alias="-s")
+    dataset_version: Optional[str] = simple_parsing.field(default="main", alias="-v")
 
     # Local processing parameters
     shuffle: bool = simple_parsing.field(default=False)
@@ -210,6 +209,15 @@
     private: bool = simple_parsing.field(default=False)
     token: Optional[str] = None
 
+    # Chunk processing parameters
+    num_chunks: int = simple_parsing.field(default=10)
+    chunk_split_threshold: int = simple_parsing.field(default=50000)
+
+    # Columns that cannot be null
+    check_empty_columns: List[str] = simple_parsing.field(
+        default_factory=lambda: ["audio"]
+    )
+
     task: Union[TtsTask, TextGenerationTask] = simple_parsing.subgroups(
         {"tts": TtsTask, "textgen": TextGenerationTask},  # type: ignore
         default_factory=TtsTask,
@@ -217,17 +225,152 @@
     )
 
     def __post_init__(self):
+        if not self.dataset_subset:
+            self.dataset_subset = "default"
         if not self.upload_subset and self.dataset_subset:
             self.upload_subset = self.dataset_subset
         if self.dataset_split and not self.upload_split:
             self.upload_split = self.dataset_split
 
 
+class DatasetChunkProcessor:
+    args: DatasetToolArgs
+    cache_dir: str = ".cache/ds_tool/processed_datasets"
+    chunks_not_uploaded: List[Tuple[int, int]] = []
+
+    def __init__(self, args: DatasetToolArgs):
+        self.args = args
+
+    def process_and_upload_split_rescursive(
+        self,
+        split_name: str,
+        ds_split: datasets.Dataset,
+        start_index: int,
+        end_index: int,
+    ):
+        original_chunk_size = end_index - start_index
+        if original_chunk_size < self.args.chunk_split_threshold:
+            total_chunks = 1
+            chunk_size = original_chunk_size
+            print(
+                f"Chunk range [{start_index}, {end_index}) is too small to split further. Processing and uploading as a single chunk."
+            )
+        else:
+            total_chunks = self.args.num_chunks
+            chunk_size = math.ceil(original_chunk_size / total_chunks)
+        failed_chunk_ranges = []
+        print(
+            f"Processing and uploading {total_chunks} chunks for range [{start_index}, {end_index}) with chunk size {chunk_size}"
+        )
+        for chunk_start in range(start_index, end_index, chunk_size):
+            chunk_end = min(chunk_start + chunk_size, end_index)
+
+            ds_chunk = ds_split.select(range(chunk_start, chunk_end))
+            ds_chunk_name = f"chunk-range-{chunk_start:09d}-{chunk_end:09d}"
+            ds_chunk_hub_path = os.path.join(
+                str(self.args.upload_subset), split_name, ds_chunk_name
+            )
+            ds_chunk_cache_path = os.path.join(
+                self.cache_dir,
+                self.args.dataset_name.replace("/", "__"),
+                str(self.args.upload_subset),
+                split_name,
+                ds_chunk_name + ".parquet",
+            )
+            try:
+                if os.path.exists(ds_chunk_cache_path):
+                    print(
+                        f"Skipping chunk {ds_chunk_name} as it has already been processed and uploaded."
+                    )
+                else:
+                    print(f"Processing chunk {ds_chunk_name}")
+                    ds_chunk_processed = self._process(ds_chunk)
+                    print(
+                        "Finished processing chunk with length", len(ds_chunk_processed)
+                    )
+                    if len(ds_chunk_processed) > 0:
+                        # Note: The caching is after the upload to avoid caching failed upload chunks.
+                        # Saved chunks indicate they have been uploaded to HF.
+                        self._upload(ds_chunk_processed, ds_chunk_hub_path, split_name)
+                        ds_chunk_processed.to_parquet(ds_chunk_cache_path)
+                    else:
+                        print(f"Chunk {ds_chunk_name} has 0 samples. Not uploading.")
+
+            except Exception as e:
+                # If the error is unsupported operand type(s) for -=: 'NoneType' and 'float',
+                # then the huggingface README needs to be updated to have the
+                # download_size and dataset_size fields present under dataset_info (they can be initialized to 0).
+                print(f"Failed to upload chunk {ds_chunk_name}: {e}. Retrying later.")
+                if total_chunks == 1:
+                    print(
+                        f"Finished processing and uploading 0/1 chunks for range [{start_index}, {end_index})"
+                    )
+                    self.chunks_not_uploaded.append((start_index, end_index))
+                    return None
+                failed_chunk_ranges.append((chunk_start, chunk_end))
+        successful_chunks = self.args.num_chunks - len(failed_chunk_ranges)
+        print(
+            f"Finished processing and uploading {successful_chunks}/{self.args.num_chunks} chunks for range [{start_index}, {end_index})"
+        )
+        if len(failed_chunk_ranges) > 0:
+            for start, end in failed_chunk_ranges:
+                print(f"Retrying failed chunk range [{start}, {end})")
+                self.process_and_upload_split_rescursive(
+                    split_name, ds_split, start, end
+                )
+
+        print(f"Could not upload chunks: {self.chunks_not_uploaded}")
+        print(f"Finished processing and uploading all chunks for split {split_name}.")
+
+    def _process(self, ds_chunk: datasets.Dataset) -> datasets.Dataset:
+        ds_mapped = self.args.task.map_split(
+            ds_chunk,
+            self.args.num_workers,
+            self.args.writer_batch_size,
+            self.args.exclude_fields,
+        )
+
+        check_empty_columns = self.args.check_empty_columns
+        if len(check_empty_columns) > 0:
+            return ds_mapped.filter(
+                lambda sample: all(
+                    sample[column] is not None for column in check_empty_columns
+                ),
+                num_proc=self.args.num_workers,
+                writer_batch_size=self.args.writer_batch_size,
+            )
+        else:
+            return ds_mapped
+
+    @retry(wait=wait_fixed(3), stop=stop_after_attempt(3))
+    def _upload(self, ds_chunk_processed: datasets.Dataset, data_dir: str, split_name):
+        print(f"Uploading chunk to hub: {data_dir}")
+        ds_split_chunked: chunked_dataset.ChunkedDataset = (
+            chunked_dataset.convert_to_chunked_dataset(ds_chunk_processed)
+        )
+
+        hub_args: Dict[str, Any] = {
+            "config_name": self.args.upload_subset,
+            "token": self.args.token or os.environ.get("HF_TOKEN"),
+            "private": self.args.private,
+            "data_dir": data_dir,
+            "num_shards": self.args.num_shards,
+            "split": split_name,
+        }
+        assert isinstance(self.args.upload_name, str)
+        ds_split_chunked.push_to_hub(self.args.upload_name, **hub_args)
+
+
 def main(args: DatasetToolArgs):
     ds_name = args.dataset_name
     print(f'Loading dataset "{ds_name}" for task {args.task}')
+    download_config = datasets.DownloadConfig(num_proc=args.num_workers, max_retries=2)
     data_dict: datasets.DatasetDict = datasets.load_dataset(
-        ds_name, args.dataset_subset, split=args.dataset_split
+        ds_name,
+        args.dataset_subset,
+        split=args.dataset_split,
+        download_config=download_config,
+        revision=args.dataset_version,
     )
 
     if isinstance(data_dict, datasets.Dataset):
@@ -236,47 +379,20 @@ def main(args: DatasetToolArgs):
     if len(data_dict) > 1 and args.upload_split:
         raise ValueError("Cannot upload multiple splits to a single split")
 
-    for split, ds_split in data_dict.items():
+    ds_chunk_proc = DatasetChunkProcessor(args)
+
+    for split_name, ds_split in data_dict.items():
         print(
-            f"Processing dataset: {ds_name}, subset {args.dataset_subset}, split {args.dataset_split}, containing {len(ds_split)} samples"
+            f"Processing dataset: {ds_name}, subset {args.dataset_subset}, split {split_name}, containing {len(ds_split)} samples"
         )
         if args.shuffle:
             ds_split = ds_split.shuffle(seed=args.shuffle_seed)
         if args.num_samples:
             ds_split = ds_split.select(range(args.num_samples))
-        data_dict[split] = args.task.map_split(
-            ds_split,
-            args.num_workers,
-            args.writer_batch_size,
-            args.exclude_fields,
-        )
-
-    hub_args: Dict[str, Any] = {
-        "config_name": args.upload_subset or "default",
-        "token": args.token or os.environ.get("HF_TOKEN"),
-        "revision": args.upload_branch,
-        "private": args.private,
-    }
-
-    if args.num_shards is not None:
-        hub_args["num_shards"] = args.num_shards
-
-    try:
-        if args.dataset_split:
-            data_dict[args.dataset_split].push_to_hub(
-                args.upload_name, split=args.upload_split, **hub_args
-            )
-        else:
-            data_dict.push_to_hub(args.upload_name, **hub_args)
-    except Exception as e:
-        print(f"Failed to push to hub: {e}")
-
-        # If the push fails or upload_name is not specified, save the data locally.
-        for split in data_dict.keys():
-            output_name = f"{split}-00000-of-00001.parquet"
-            data_dict[split].to_parquet(output_name)
-            print(f"Saved to {output_name}")
-            print(f"Sample {0} of {split}: {data_dict[split][0]}")
+        ds_chunk_proc.process_and_upload_split_rescursive(
+            split_name, ds_split, 0, len(ds_split)
+        )
 
 
 if __name__ == "__main__":
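The recursion above always divides a failing range into `num_chunks` pieces and retries them, until a piece falls below `chunk_split_threshold`, at which point it is processed as a single chunk and, on failure, recorded in `chunks_not_uploaded`. A standalone sketch of just the range arithmetic (the helper name and example numbers are made up):

# Sketch of DatasetChunkProcessor's splitting rule, without any processing/uploading.
import math
from typing import List, Tuple

def plan_chunks(
    start_index: int,
    end_index: int,
    num_chunks: int = 10,
    chunk_split_threshold: int = 50000,
) -> List[Tuple[int, int]]:
    size = end_index - start_index
    if size < chunk_split_threshold:
        return [(start_index, end_index)]  # too small to split further
    chunk_size = math.ceil(size / num_chunks)
    return [
        (s, min(s + chunk_size, end_index))
        for s in range(start_index, end_index, chunk_size)
    ]

print(plan_chunks(0, 120_000))  # ten chunks of 12,000 rows each
print(plan_chunks(0, 30_000))   # below the threshold -> a single chunk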