diff --git a/ultravox/data/dataset_config.py b/ultravox/data/dataset_config.py
index 3d6e3422..ea1f0a60 100644
--- a/ultravox/data/dataset_config.py
+++ b/ultravox/data/dataset_config.py
@@ -11,7 +11,7 @@ class DataDictConfig(BaseModel):
     name: Optional[str] = None
     splits: List[str] = dataclasses.field(default_factory=list)
     num_samples: Optional[int] = None
-    total_samples: int
+    total_samples: int = 1
     weight: float = 1.0
     streaming: bool = True
     user_template: str = "<|audio|>"
diff --git a/ultravox/data/datasets.py b/ultravox/data/datasets.py
index b136f839..aa3fb1b9 100644
--- a/ultravox/data/datasets.py
+++ b/ultravox/data/datasets.py
@@ -298,7 +298,7 @@ def __init__(self, args: VoiceDatasetArgs) -> None:
         self._rng = np.random.default_rng(self._args.shuffle_seed)
         self._weight = 1.0  # the default weight for the dataset

-    def _init_dataset(self, dataset: data.Dataset, estimated_length: int = 0) -> None:
+    def _init_dataset(self, dataset: data.Dataset, estimated_length: int = 1) -> None:
         self._dataset = dataset
         # Only required when using epochs when training dataset.
         self._estimated_length = estimated_length
@@ -363,12 +363,12 @@ def __iter__(self):
             actual_length += 1
         # If len(dataset) == 0 most likely the dataset is a validation dataset,
         # or the training is using max_steps instead of num_epochs.
-        if actual_length > len(self) and len(self) > 0:
+        if actual_length > len(self) and len(self) > 1:
             warnings.warn(
                 f"The estimated length {self._estimated_length} has been exceeded for type {type(self._dataset)}. Make sure to update."
             )

-        if actual_length != len(self) and len(self) > 0:
+        if actual_length != len(self) and len(self) > 1:
             warnings.warn(
                 f"Mismatch between estimated length ({self._estimated_length}) and actual length ({actual_length}) for dataset of type {type(self._dataset)}. Make sure to update."
             )
@@ -484,7 +484,7 @@ def _get_sample(self, row: transformers.BatchFeature) -> VoiceSample:

 # Making EmptyDataset a SizedIterableDataset to be compatible with using epochs during training.
 class EmptyDataset(SizedIterableDataset):
-    def __init__(self, estimated_length: int = 0) -> None:
+    def __init__(self, estimated_length: int = 1) -> None:
         self._estimated_length = estimated_length

     def __iter__(self):
diff --git a/ultravox/training/train.py b/ultravox/training/train.py
index 3884658a..05cea992 100644
--- a/ultravox/training/train.py
+++ b/ultravox/training/train.py
@@ -40,6 +40,7 @@ def fix_hyphens(arg: str):


 def prepare_dataset(
+    train_args: config_base.TrainConfig,
     dataset_names: List[str],
     data_args: datasets.VoiceDatasetArgs,
     processor: ultravox_processing.UltravoxProcessor,
@@ -48,8 +49,15 @@
     num_samples: Optional[int] = None,
     include_alt_fields: bool = False,  # whether to generate tensors for text-only input (e.g., used for KD training)
 ) -> datasets.SizedIterableDataset:
-    data_sets = [datasets.create_dataset(ds, data_args) for ds in dataset_names]
+    data_sets = [datasets.create_dataset(ds, data_args) for ds in dataset_names]
+    # If we're using epochs to train, validate the dataset length is appropriate.
+    if train_args.max_steps == 0:
+        for ds in data_sets:
+            assert (
+                len(ds) > 1
+            ), f"Dataset {ds} has length {len(ds)} which is too short for epoch training"
+

     interleave = datasets.InterleaveDataset(data_sets, stop_strategy=stop_strategy)
     ds_with_proc = data_processing.UltravoxDataproc(
         interleave,
@@ -196,6 +203,7 @@
         + [(f"text_{x}", [x]) for x in args.val_sets]
     )
     train_dataset = prepare_dataset(
+        train_args=args,
         dataset_names=args.data_sets,
         train_on_inputs=args.train_on_inputs,
         stop_strategy=args.stop_strategy,
@@ -226,6 +234,7 @@
     val_ds_args_text.include_audio = False
     val_datasets = {
         k: prepare_dataset(
+            train_args=args,
             dataset_names=val_sets[k],
             train_on_inputs=args.train_on_inputs,
             stop_strategy=args.stop_strategy,