feat: support multi validation datasets #192

Merged: 5 commits, May 31, 2022

Changes from all commits

76 changes: 75 additions & 1 deletion src/dalle_mini/data.py
@@ -1,6 +1,7 @@
import random
from dataclasses import dataclass, field
from functools import partial
from pathlib import Path

import jax
import jax.numpy as jnp
@@ -34,8 +35,10 @@ class Dataset:
max_clip_score: float = None
filter_column: str = None
filter_value: str = None
multi_eval_ds: bool = False
train_dataset: Dataset = field(init=False)
eval_dataset: Dataset = field(init=False)
other_eval_datasets: list = field(init=False)
rng_dataset: jnp.ndarray = field(init=False)
multi_hosts: bool = field(init=False)
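
Side note on the new fields: `other_eval_datasets` is annotated as `list` but is assigned a dict in `__post_init__` below, and because it is `init=False` and only set conditionally, later code guards access with `hasattr`. A minimal construction sketch (field names from this class, values hypothetical):

```python
dataset = Dataset(
    dataset_repo_or_path="/data/my_dataset",  # local dir with one subdirectory per split
    multi_eval_ds=True,
)
```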

@@ -75,6 +78,21 @@ def __post_init__(self):
else:
data_files = None

# multiple validation datasets
if self.multi_eval_ds:
assert Path(
self.dataset_repo_or_path
).is_dir(), f"{self.dataset_repo_or_path} is not a directory, required for multi_eval_ds"
data_files = {
split.name: [str(f) for f in split.glob("*.parquet")]
for split in Path(self.dataset_repo_or_path).glob("*")
}
# rename "valid" to "validation" if present for consistency
if "valid" in data_files:
data_files["validation"] = data_files["valid"]
del data_files["valid"]
self.dataset_repo_or_path = "parquet"
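
For context, a sketch of the local layout this branch expects and the `data_files` mapping the glob produces (all paths hypothetical):

```python
from pathlib import Path

# Hypothetical layout:
#   /data/my_dataset/train/*.parquet
#   /data/my_dataset/valid/*.parquet      -> renamed to "validation" below
#   /data/my_dataset/eval_coco/*.parquet  -> becomes an extra eval split
root = Path("/data/my_dataset")
data_files = {
    split.name: [str(f) for f in split.glob("*.parquet")]
    for split in root.glob("*")
}
if "valid" in data_files:
    data_files["validation"] = data_files.pop("valid")
# -> {"train": [...], "validation": [...], "eval_coco": [...]}
```

Since the glob does not check `is_dir()`, a stray top-level file would produce an empty split, so a clean dataset directory is assumed.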

# load dataset
dataset = load_dataset(
self.dataset_repo_or_path,
@@ -102,6 +120,11 @@ def __post_init__(self):
if self.streaming
else self.eval_dataset.select(range(self.max_eval_samples))
)
# other eval datasets
other_eval_splits = dataset.keys() - {"train", "validation"}
self.other_eval_datasets = {
split: dataset[split] for split in other_eval_splits
}
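
Because `dataset_repo_or_path` has been rewritten to `"parquet"`, the stock `datasets` parquet builder loads every key of `data_files` as its own split, and anything beyond `train` and `validation` is collected here as an extra eval set. A sketch continuing the hypothetical example above:

```python
from datasets import load_dataset

dataset = load_dataset("parquet", data_files=data_files)
other_eval_splits = dataset.keys() - {"train", "validation"}  # e.g. {"eval_coco"}
other_eval_datasets = {split: dataset[split] for split in other_eval_splits}
```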

def preprocess(self, tokenizer, config):
# get required config variables
@@ -143,6 +166,20 @@ def preprocess(self, tokenizer, config):
)
),
)
if hasattr(self, "other_eval_datasets"):
self.other_eval_datasets = {
split: (
ds.filter(partial_filter_function)
if self.streaming
else ds.filter(
partial_filter_function,
num_proc=self.preprocessing_num_workers,
load_from_cache_file=not self.overwrite_cache,
desc="Filtering datasets",
)
)
for split, ds in self.other_eval_datasets.items()
}
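
The same `hasattr` guard plus streaming/non-streaming dispatch recurs below for normalization and, with extra arguments, preprocessing. A hypothetical helper, not part of this PR, could factor out the shared shape for the first two cases:

```python
def _transform_other_eval(self, method, *args, **non_streaming_kwargs):
    # Hypothetical refactor: apply ds.filter(...) or ds.map(...) to every extra
    # eval split, passing the heavier num_proc/caching kwargs only when the
    # dataset is not streamed.
    if not hasattr(self, "other_eval_datasets"):
        return
    self.other_eval_datasets = {
        split: getattr(ds, method)(*args)
        if self.streaming
        else getattr(ds, method)(*args, **non_streaming_kwargs)
        for split, ds in self.other_eval_datasets.items()
    }
```

The filter step above would then collapse to a single call: `self._transform_other_eval("filter", partial_filter_function, num_proc=self.preprocessing_num_workers, load_from_cache_file=not self.overwrite_cache, desc="Filtering datasets")`.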

# normalize text
if normalize_text:
@@ -168,6 +205,20 @@ def preprocess(self, tokenizer, config):
)
),
)
if hasattr(self, "other_eval_datasets"):
self.other_eval_datasets = {
split: (
ds.map(partial_normalize_function)
if self.streaming
else ds.map(
partial_normalize_function,
num_proc=self.preprocessing_num_workers,
load_from_cache_file=not self.overwrite_cache,
desc="Normalizing datasets",
)
)
for split, ds in self.other_eval_datasets.items()
}

# blank captions
if self.blank_caption_prob:
@@ -225,6 +276,29 @@ def preprocess(self, tokenizer, config):
)
),
)
if hasattr(self, "other_eval_datasets"):
self.other_eval_datasets = {
split: (
ds.map(
partial_preprocess_function,
batched=True,
remove_columns=[
self.text_column,
self.encoding_column,
],
)
if self.streaming
else ds.map(
partial_preprocess_function,
batched=True,
remove_columns=getattr(ds, "column_names"),
num_proc=self.preprocessing_num_workers,
load_from_cache_file=not self.overwrite_cache,
desc="Preprocessing datasets",
)
)
for split, ds in self.other_eval_datasets.items()
}

def dataloader(self, split, batch_size, epoch=None):
def _dataloader_datasets_non_streaming(
@@ -283,7 +357,7 @@ def _dataloader_datasets_streaming(
elif split == "eval":
ds = self.eval_dataset
else:
-            raise ValueError(f'split must be "train" or "eval", got {split}')
+            ds = self.other_eval_datasets[split]

if self.streaming:
return _dataloader_datasets_streaming(ds, epoch)
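
With the fallback in place, any extra split can be requested from `dataloader` by name; a typo in the split name now surfaces as a `KeyError` from the dict lookup rather than the old explicit `ValueError`. Usage sketch (split name hypothetical):

```python
loader = dataset.dataloader("eval_coco", batch_size=64)
for batch in loader:
    ...  # run an eval step on the batch
```
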
113 changes: 65 additions & 48 deletions tools/train/train.py
@@ -250,6 +256,12 @@ class DataTrainingArguments:
default=None,
metadata={"help": "Class value to be kept during filtering."},
)
multi_eval_ds: Optional[bool] = field(
default=False,
metadata={
"help": "Whether to look for multiple validation datasets (local support only)."
},
)
max_train_samples: Optional[int] = field(
default=None,
metadata={
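
Assuming these dataclass fields are exposed as CLI flags through `HfArgumentParser`, as elsewhere in the script, enabling the feature would look roughly like this (paths illustrative; the bare flag may also work depending on how the parser binds optional booleans):

```bash
python tools/train/train.py \
    --dataset_repo_or_path /data/my_dataset \
    --multi_eval_ds true \
    ...
```
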
@@ -1383,62 +1389,73 @@ def run_evaluation():
        # ======================== Evaluating ==============================
        if training_args.do_eval:
            start_eval_time = time.perf_counter()
-            eval_loader = dataset.dataloader(
-                "eval",
-                eval_batch_size_per_step
-                * max(1, training_args.mp_devices // jax.local_device_count()),
-            )
-            eval_steps = (
-                len_eval_dataset // eval_batch_size_per_step
-                if len_eval_dataset is not None
-                else None
-            )
-            eval_loss = []
-            for batch in tqdm(
-                eval_loader,
-                desc="Evaluating...",
-                position=2,
-                leave=False,
-                total=eval_steps,
-                disable=jax.process_index() > 0,
-            ):
-                # need to keep only eval_batch_size_per_node items relevant to the node
-                batch = jax.tree_map(
-                    lambda x: x.reshape(
-                        (jax.process_count(), eval_batch_size_per_node) + x.shape[1:]
-                    ),
-                    batch,
-                )
-                batch = jax.tree_map(lambda x: x[jax.process_index()], batch)
-
-                # add dp dimension when using "vmap trick"
-                if use_vmap_trick:
-                    bs_shape = (
-                        jax.local_device_count() // training_args.mp_devices,
-                        training_args.per_device_eval_batch_size,
-                    )
-                    batch = jax.tree_map(
-                        lambda x: x.reshape(bs_shape + x.shape[1:]), batch
-                    )
-
-                # freeze batch to pass safely to jax transforms
-                batch = freeze(batch)
-                # accumulate losses async
-                eval_loss.append(p_eval_step(state, batch))
-
-            # get the mean of the loss
-            eval_loss = jnp.stack(eval_loss)
-            eval_loss = jnp.mean(eval_loss)
-            eval_metrics = {"loss": eval_loss}
-
-            # log metrics
-            metrics_logger.log(eval_metrics, prefix="eval")
-            metrics_logger.log_time("eval", time.perf_counter() - start_eval_time)
-
-            # Print metrics and update progress bar
-            desc = f"Epoch... ({epoch + 1}/{num_epochs} | Eval Loss: {eval_metrics['loss']})"
-            epochs.write(desc)
-            epochs.desc = desc
+            # get validation datasets
+            val_datasets = list(
+                dataset.other_eval_datasets.keys()
+                if hasattr(dataset, "other_eval_datasets")
+                else []
+            )
+            val_datasets += ["eval"]
+            for val_dataset in val_datasets:
+                eval_loader = dataset.dataloader(
+                    val_dataset,
+                    eval_batch_size_per_step
+                    * max(1, training_args.mp_devices // jax.local_device_count()),
+                )
+                eval_steps = (
+                    len_eval_dataset // eval_batch_size_per_step
+                    if len_eval_dataset is not None
+                    else None
+                )
+                eval_loss = []
+                for batch in tqdm(
+                    eval_loader,
+                    desc="Evaluating...",
+                    position=2,
+                    leave=False,
+                    total=eval_steps,
+                    disable=jax.process_index() > 0,
+                ):
+                    # need to keep only eval_batch_size_per_node items relevant to the node
+                    batch = jax.tree_map(
+                        lambda x: x.reshape(
+                            (jax.process_count(), eval_batch_size_per_node)
+                            + x.shape[1:]
+                        ),
+                        batch,
+                    )
+                    batch = jax.tree_map(lambda x: x[jax.process_index()], batch)
+
+                    # add dp dimension when using "vmap trick"
+                    if use_vmap_trick:
+                        bs_shape = (
+                            jax.local_device_count() // training_args.mp_devices,
+                            training_args.per_device_eval_batch_size,
+                        )
+                        batch = jax.tree_map(
+                            lambda x: x.reshape(bs_shape + x.shape[1:]), batch
+                        )
+
+                    # freeze batch to pass safely to jax transforms
+                    batch = freeze(batch)
+                    # accumulate losses async
+                    eval_loss.append(p_eval_step(state, batch))
+
+                # get the mean of the loss
+                eval_loss = jnp.stack(eval_loss)
+                eval_loss = jnp.mean(eval_loss)
+                eval_metrics = {"loss": eval_loss}
+
+                # log metrics
+                metrics_logger.log(eval_metrics, prefix=val_dataset)
+
+                # Print metrics and update progress bar
+                desc = f"Epoch... ({epoch + 1}/{num_epochs} | {val_dataset} Loss: {eval_metrics['loss']})"
+                epochs.write(desc)
+                epochs.desc = desc
+
+            # log time
+            metrics_logger.log_time("eval", time.perf_counter() - start_eval_time)

            return eval_metrics
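
With `prefix=val_dataset`, each split now logs under its own namespace, e.g. `eval/loss` for the main validation set and `eval_coco/loss` for a hypothetical extra split (the exact key format depends on `metrics_logger`). One caveat visible in the loop: `eval_steps` is still derived from `len_eval_dataset` for every split, so the tqdm total is exact only for the main validation set; the extra splits still iterate to completion, only the progress-bar estimate is off.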
