Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Assert dataset length when using epochs #104

Merged
Merged 1 commit on Sep 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion ultravox/data/dataset_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ class DataDictConfig(BaseModel):
name: Optional[str] = None
splits: List[str] = dataclasses.field(default_factory=list)
num_samples: Optional[int] = None
total_samples: int
total_samples: int = 1
weight: float = 1.0
streaming: bool = True
user_template: str = "<|audio|>"
Expand Down
8 changes: 4 additions & 4 deletions ultravox/data/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,7 +298,7 @@ def __init__(self, args: VoiceDatasetArgs) -> None:
self._rng = np.random.default_rng(self._args.shuffle_seed)
self._weight = 1.0 # the default weight for the dataset

def _init_dataset(self, dataset: data.Dataset, estimated_length: int = 0) -> None:
def _init_dataset(self, dataset: data.Dataset, estimated_length: int = 1) -> None:
self._dataset = dataset
# Only required when using epochs when training dataset.
self._estimated_length = estimated_length
Expand Down Expand Up @@ -363,12 +363,12 @@ def __iter__(self):
actual_length += 1
# If len(dataset) == 0 most likely the dataset is a validation dataset,
# or the training is using max_steps instead of num_epochs.
if actual_length > len(self) and len(self) > 0:
if actual_length > len(self) and len(self) > 1:
warnings.warn(
f"The estimated length {self._estimated_length} has been exceeded for type {type(self._dataset)}. Make sure to update."
)

if actual_length != len(self) and len(self) > 0:
if actual_length != len(self) and len(self) > 1:
warnings.warn(
f"Mismatch between estimated length ({self._estimated_length}) and actual length ({actual_length}) for dataset of type {type(self._dataset)}. Make sure to update."
)
Expand Down Expand Up @@ -484,7 +484,7 @@ def _get_sample(self, row: transformers.BatchFeature) -> VoiceSample:

# Making EmptyDataset a SizedIterableDataset to be compatible with using epochs during training.
class EmptyDataset(SizedIterableDataset):
def __init__(self, estimated_length: int = 0) -> None:
def __init__(self, estimated_length: int = 1) -> None:
self._estimated_length = estimated_length

def __iter__(self):
Expand Down
11 changes: 10 additions & 1 deletion ultravox/training/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def fix_hyphens(arg: str):


def prepare_dataset(
train_args: config_base.TrainConfig,
dataset_names: List[str],
data_args: datasets.VoiceDatasetArgs,
processor: ultravox_processing.UltravoxProcessor,
Expand All @@ -48,8 +49,14 @@ def prepare_dataset(
num_samples: Optional[int] = None,
include_alt_fields: bool = False, # whether to generate tensors for text-only input (e.g., used for KD training)
) -> datasets.SizedIterableDataset:

data_sets = [datasets.create_dataset(ds, data_args) for ds in dataset_names]
# If we're using epochs to train, validate the dataset length is appropriate.
if train_args.max_steps == 0:
for ds in data_sets:
assert (
len(ds) > 1
), f"Dataset {ds} has length {len(ds)} which is too short for epoch training"

interleave = datasets.InterleaveDataset(data_sets, stop_strategy=stop_strategy)
ds_with_proc = data_processing.UltravoxDataproc(
interleave,
Expand Down Expand Up @@ -196,6 +203,7 @@ def train(args: config_base.TrainConfig):
+ [(f"text_{x}", [x]) for x in args.val_sets]
)
train_dataset = prepare_dataset(
train_args=args,
dataset_names=args.data_sets,
train_on_inputs=args.train_on_inputs,
stop_strategy=args.stop_strategy,
Expand Down Expand Up @@ -226,6 +234,7 @@ def train(args: config_base.TrainConfig):
val_ds_args_text.include_audio = False
val_datasets = {
k: prepare_dataset(
train_args=args,
dataset_names=val_sets[k],
train_on_inputs=args.train_on_inputs,
stop_strategy=args.stop_strategy,
Expand Down
Loading