diff --git a/src/dalle_mini/data.py b/src/dalle_mini/data.py index 78488765d..8ce6ca9d4 100644 --- a/src/dalle_mini/data.py +++ b/src/dalle_mini/data.py @@ -1,6 +1,7 @@ import random from dataclasses import dataclass, field from functools import partial +from pathlib import Path import jax import jax.numpy as jnp @@ -34,8 +35,10 @@ class Dataset: max_clip_score: float = None filter_column: str = None filter_value: str = None + multi_eval_ds: bool = False train_dataset: Dataset = field(init=False) eval_dataset: Dataset = field(init=False) + other_eval_datasets: list = field(init=False) rng_dataset: jnp.ndarray = field(init=False) multi_hosts: bool = field(init=False) @@ -75,6 +78,21 @@ def __post_init__(self): else: data_files = None + # multiple validation datasets + if self.multi_eval_ds: + assert Path( + self.dataset_repo_or_path + ).is_dir(), f"{self.dataset_repo_or_path} is not a directory, required for multi_eval_ds" + data_files = { + split.name: [str(f) for f in split.glob("*.parquet")] + for split in Path(self.dataset_repo_or_path).glob("*") + } + # rename "valid" to "validation" if present for consistency + if "valid" in data_files: + data_files["validation"] = data_files["valid"] + del data_files["valid"] + self.dataset_repo_or_path = "parquet" + # load dataset dataset = load_dataset( self.dataset_repo_or_path, @@ -102,6 +120,11 @@ def __post_init__(self): if self.streaming else self.eval_dataset.select(range(self.max_eval_samples)) ) + # other eval datasets + other_eval_splits = dataset.keys() - {"train", "validation"} + self.other_eval_datasets = { + split: dataset[split] for split in other_eval_splits + } def preprocess(self, tokenizer, config): # get required config variables @@ -143,6 +166,20 @@ def preprocess(self, tokenizer, config): ) ), ) + if hasattr(self, "other_eval_datasets"): + self.other_eval_datasets = { + split: ( + ds.filter(partial_filter_function) + if self.streaming + else ds.filter( + partial_filter_function, + num_proc=self.preprocessing_num_workers, + load_from_cache_file=not self.overwrite_cache, + desc="Filtering datasets", + ) + ) + for split, ds in self.other_eval_datasets.items() + } # normalize text if normalize_text: @@ -168,6 +205,20 @@ def preprocess(self, tokenizer, config): ) ), ) + if hasattr(self, "other_eval_datasets"): + self.other_eval_datasets = { + split: ( + ds.map(partial_normalize_function) + if self.streaming + else ds.map( + partial_normalize_function, + num_proc=self.preprocessing_num_workers, + load_from_cache_file=not self.overwrite_cache, + desc="Normalizing datasets", + ) + ) + for split, ds in self.other_eval_datasets.items() + } # blank captions if self.blank_caption_prob: @@ -225,6 +276,29 @@ def preprocess(self, tokenizer, config): ) ), ) + if hasattr(self, "other_eval_datasets"): + self.other_eval_datasets = { + split: ( + ds.map( + partial_preprocess_function, + batched=True, + remove_columns=[ + self.text_column, + self.encoding_column, + ], + ) + if self.streaming + else ds.map( + partial_preprocess_function, + batched=True, + remove_columns=getattr(ds, "column_names"), + num_proc=self.preprocessing_num_workers, + load_from_cache_file=not self.overwrite_cache, + desc="Preprocessing datasets", + ) + ) + for split, ds in self.other_eval_datasets.items() + } def dataloader(self, split, batch_size, epoch=None): def _dataloader_datasets_non_streaming( @@ -283,7 +357,7 @@ def _dataloader_datasets_streaming( elif split == "eval": ds = self.eval_dataset else: - raise ValueError(f'split must be "train" or "eval", got {split}') + ds = self.other_eval_datasets[split] if self.streaming: return _dataloader_datasets_streaming(ds, epoch) diff --git a/tools/train/train.py b/tools/train/train.py index 581557e35..abc66fe6c 100644 --- a/tools/train/train.py +++ b/tools/train/train.py @@ -250,6 +250,12 @@ class DataTrainingArguments: default=None, metadata={"help": "Class value to be kept during filtering."}, ) + multi_eval_ds: Optional[bool] = field( + default=False, + metadata={ + "help": "Whether to look for multiple validation datasets (local support only)." + }, + ) max_train_samples: Optional[int] = field( default=None, metadata={ @@ -1383,62 +1389,73 @@ def run_evaluation(): # ======================== Evaluating ============================== if training_args.do_eval: start_eval_time = time.perf_counter() - eval_loader = dataset.dataloader( - "eval", - eval_batch_size_per_step - * max(1, training_args.mp_devices // jax.local_device_count()), - ) - eval_steps = ( - len_eval_dataset // eval_batch_size_per_step - if len_eval_dataset is not None - else None + # get validation datasets + val_datasets = list( + dataset.other_eval_datasets.keys() + if hasattr(dataset, "other_eval_datasets") + else [] ) - eval_loss = [] - for batch in tqdm( - eval_loader, - desc="Evaluating...", - position=2, - leave=False, - total=eval_steps, - disable=jax.process_index() > 0, - ): - # need to keep only eval_batch_size_per_node items relevant to the node - batch = jax.tree_map( - lambda x: x.reshape( - (jax.process_count(), eval_batch_size_per_node) + x.shape[1:] - ), - batch, + val_datasets += ["eval"] + for val_dataset in val_datasets: + eval_loader = dataset.dataloader( + val_dataset, + eval_batch_size_per_step + * max(1, training_args.mp_devices // jax.local_device_count()), ) - batch = jax.tree_map(lambda x: x[jax.process_index()], batch) - - # add dp dimension when using "vmap trick" - if use_vmap_trick: - bs_shape = ( - jax.local_device_count() // training_args.mp_devices, - training_args.per_device_eval_batch_size, - ) + eval_steps = ( + len_eval_dataset // eval_batch_size_per_step + if len_eval_dataset is not None + else None + ) + eval_loss = [] + for batch in tqdm( + eval_loader, + desc="Evaluating...", + position=2, + leave=False, + total=eval_steps, + disable=jax.process_index() > 0, + ): + # need to keep only eval_batch_size_per_node items relevant to the node batch = jax.tree_map( - lambda x: x.reshape(bs_shape + x.shape[1:]), batch + lambda x: x.reshape( + (jax.process_count(), eval_batch_size_per_node) + + x.shape[1:] + ), + batch, ) + batch = jax.tree_map(lambda x: x[jax.process_index()], batch) - # freeze batch to pass safely to jax transforms - batch = freeze(batch) - # accumulate losses async - eval_loss.append(p_eval_step(state, batch)) + # add dp dimension when using "vmap trick" + if use_vmap_trick: + bs_shape = ( + jax.local_device_count() // training_args.mp_devices, + training_args.per_device_eval_batch_size, + ) + batch = jax.tree_map( + lambda x: x.reshape(bs_shape + x.shape[1:]), batch + ) - # get the mean of the loss - eval_loss = jnp.stack(eval_loss) - eval_loss = jnp.mean(eval_loss) - eval_metrics = {"loss": eval_loss} + # freeze batch to pass safely to jax transforms + batch = freeze(batch) + # accumulate losses async + eval_loss.append(p_eval_step(state, batch)) - # log metrics - metrics_logger.log(eval_metrics, prefix="eval") - metrics_logger.log_time("eval", time.perf_counter() - start_eval_time) + # get the mean of the loss + eval_loss = jnp.stack(eval_loss) + eval_loss = jnp.mean(eval_loss) + eval_metrics = {"loss": eval_loss} + + # log metrics + metrics_logger.log(eval_metrics, prefix=val_dataset) - # Print metrics and update progress bar - desc = f"Epoch... ({epoch + 1}/{num_epochs} | Eval Loss: {eval_metrics['loss']})" - epochs.write(desc) - epochs.desc = desc + # Print metrics and update progress bar + desc = f"Epoch... ({epoch + 1}/{num_epochs} | {val_dataset} Loss: {eval_metrics['loss']})" + epochs.write(desc) + epochs.desc = desc + + # log time + metrics_logger.log_time("eval", time.perf_counter() - start_eval_time) return eval_metrics