Skip to content

Commit

Permalink
Assert dataset length when using epochs (#104)
Browse files Browse the repository at this point in the history
  • Loading branch information
liPatrick authored Sep 4, 2024
1 parent 7358f14 commit 74e3998
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 6 deletions.
2 changes: 1 addition & 1 deletion ultravox/data/dataset_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ class DataDictConfig(BaseModel):
name: Optional[str] = None
splits: List[str] = dataclasses.field(default_factory=list)
num_samples: Optional[int] = None
total_samples: int
total_samples: int = 1
weight: float = 1.0
streaming: bool = True
user_template: str = "<|audio|>"
Expand Down
8 changes: 4 additions & 4 deletions ultravox/data/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,7 +298,7 @@ def __init__(self, args: VoiceDatasetArgs) -> None:
self._rng = np.random.default_rng(self._args.shuffle_seed)
self._weight = 1.0 # the default weight for the dataset

def _init_dataset(self, dataset: data.Dataset, estimated_length: int = 0) -> None:
def _init_dataset(self, dataset: data.Dataset, estimated_length: int = 1) -> None:
self._dataset = dataset
# Only required when training by epochs (rather than by max_steps).
self._estimated_length = estimated_length
Expand Down Expand Up @@ -363,12 +363,12 @@ def __iter__(self):
actual_length += 1
# If len(self) is at most 1 (the default estimated length), most likely the dataset is a validation dataset,
# or the training is using max_steps instead of num_epochs.
if actual_length > len(self) and len(self) > 0:
if actual_length > len(self) and len(self) > 1:
warnings.warn(
f"The estimated length {self._estimated_length} has been exceeded for type {type(self._dataset)}. Make sure to update."
)

if actual_length != len(self) and len(self) > 0:
if actual_length != len(self) and len(self) > 1:
warnings.warn(
f"Mismatch between estimated length ({self._estimated_length}) and actual length ({actual_length}) for dataset of type {type(self._dataset)}. Make sure to update."
)
Expand Down Expand Up @@ -484,7 +484,7 @@ def _get_sample(self, row: transformers.BatchFeature) -> VoiceSample:

# Making EmptyDataset a SizedIterableDataset to be compatible with using epochs during training.
class EmptyDataset(SizedIterableDataset):
def __init__(self, estimated_length: int = 0) -> None:
def __init__(self, estimated_length: int = 1) -> None:
self._estimated_length = estimated_length

def __iter__(self):
Expand Down
11 changes: 10 additions & 1 deletion ultravox/training/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def fix_hyphens(arg: str):


def prepare_dataset(
train_args: config_base.TrainConfig,
dataset_names: List[str],
data_args: datasets.VoiceDatasetArgs,
processor: ultravox_processing.UltravoxProcessor,
Expand All @@ -48,8 +49,14 @@ def prepare_dataset(
num_samples: Optional[int] = None,
include_alt_fields: bool = False, # whether to generate tensors for text-only input (e.g., used for KD training)
) -> datasets.SizedIterableDataset:

data_sets = [datasets.create_dataset(ds, data_args) for ds in dataset_names]
# If we're using epochs to train, validate the dataset length is appropriate.
if train_args.max_steps == 0:
for ds in data_sets:
assert (
len(ds) > 1
), f"Dataset {ds} has length {len(ds)} which is too short for epoch training"

interleave = datasets.InterleaveDataset(data_sets, stop_strategy=stop_strategy)
ds_with_proc = data_processing.UltravoxDataproc(
interleave,
Expand Down Expand Up @@ -196,6 +203,7 @@ def train(args: config_base.TrainConfig):
+ [(f"text_{x}", [x]) for x in args.val_sets]
)
train_dataset = prepare_dataset(
train_args=args,
dataset_names=args.data_sets,
train_on_inputs=args.train_on_inputs,
stop_strategy=args.stop_strategy,
Expand Down Expand Up @@ -226,6 +234,7 @@ def train(args: config_base.TrainConfig):
val_ds_args_text.include_audio = False
val_datasets = {
k: prepare_dataset(
train_args=args,
dataset_names=val_sets[k],
train_on_inputs=args.train_on_inputs,
stop_strategy=args.stop_strategy,
Expand Down

0 comments on commit 74e3998

Please sign in to comment.