Skip to content

Commit

Permalink
Support multimodal airio vocabularies.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 576217024
  • Loading branch information
texasmichelle authored and t5-copybara committed Oct 24, 2023
1 parent 0a56776 commit b295e55
Show file tree
Hide file tree
Showing 3 changed files with 80 additions and 27 deletions.
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
},
scripts=[],
install_requires=[
'airio @ git+https://github.com/google/airio#egg=airio',
'absl-py',
'cached_property',
'clu @ git+https://github.com/google/CommonLoopUtils#egg=clu',
Expand Down
14 changes: 9 additions & 5 deletions t5x/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,9 +305,9 @@ def _verify_matching_vocabs(cfg: utils.DatasetConfig):
partitioner=partitioner,
data_layout=data_layout,
)

input_shapes = jax.tree_map(
lambda x: (data_layout.batch_size, *x.shape[1:]), train_iter.element_spec
lambda x: (data_layout.batch_size, *x.shape[1:]),
train_iter.element_spec,
)
input_types = jax.tree_map(lambda x: x.dtype, train_iter.element_spec)

Expand Down Expand Up @@ -513,9 +513,13 @@ def _run_training_eval(first_run: bool = False):
}
)
logging.info('Computing training evaluation metrics.')
eval_batch_iters = {
task: ds.as_numpy_iterator() for task, ds in train_eval_datasets.items()
}
eval_batch_iters = {}
for task, ds in train_eval_datasets.items():
if isinstance(ds, tf.data.Dataset):
eval_batch_iters[task] = ds.as_numpy_iterator()
else:
eval_batch_iters[task] = ds

eval_summaries = trainer.eval(eval_batch_iters)
trainer.stop_training = run_actions(
trainer_lib.ActionMode.TRAIN_EVAL, # pytype: disable=wrong-arg-types # jax-ndarray
Expand Down
92 changes: 70 additions & 22 deletions t5x/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@

from absl import flags
from absl import logging
import airio
import clu.data
import flax
from flax import traverse_util
Expand Down Expand Up @@ -519,7 +520,9 @@ def restore(
class DatasetConfig:
"""Configuration for loading a dataset from a SeqIO Task or Mixture."""

mixture_or_task_name: Union[str, seqio.Task, seqio.Mixture]
mixture_or_task_name: Union[
str, seqio.Task, seqio.Mixture, airio.Task, airio.Mixture
]
task_feature_lengths: Mapping[str, int]
split: str
batch_size: int # Number of examples per batch.
Expand Down Expand Up @@ -698,13 +701,9 @@ def restore(self, filename):

@property
def iterator(self):
return (
self._iterator.iterator
if isinstance(
self._iterator, clu.data.dataset_iterator.TfDatasetIterator
)
else self._iterator
)
if isinstance(self._iterator, clu.data.dataset_iterator.TfDatasetIterator):
return self._iterator.iterator
return self._iterator


def prepare_train_iter(
Expand All @@ -717,6 +716,8 @@ def prepare_train_iter(
data_layout,
) -> clu.data.dataset_iterator.PeekableDatasetIterator:
"""Prepares the training input iterator."""
if isinstance(train_iter, airio.PyGrainDatasetIteratorWrapper):
return train_iter
if isinstance(train_iter, tf.data.Dataset):
train_iter = clu.data.dataset_iterator.TfDatasetIterator(
train_iter, checkpoint=True
Expand Down Expand Up @@ -789,13 +790,20 @@ def get_zeros_batch_like_spec(


def get_zeros_batch_like_dataset(
dataset: tf.data.Dataset, batch_size=None
dataset: Union[tf.data.Dataset, airio.PyGrainDatasetIteratorWrapper],
batch_size=None,
) -> Mapping[str, jnp.ndarray]:
"""Get zeros batch like the dataset spec."""
reshape = lambda s: (batch_size,) + s[1:] if batch_size else tuple(s)
batch_spec = {
k: jax.ShapeDtypeStruct(reshape(t.shape), t.dtype.as_numpy_dtype)
for k, t in dataset.element_spec.items()
}
batch_spec = {}
for key, val in dataset.element_spec.items(): # pytype: disable=attribute-error
if isinstance(dataset, tf.data.Dataset):
static_attributes = jax.ShapeDtypeStruct(
reshape(val.shape), val.dtype.as_numpy_dtype
)
else:
static_attributes = jax.ShapeDtypeStruct(val.shape, val.dtype)
batch_spec[key] = static_attributes
return get_zeros_batch_like_spec(batch_spec)


Expand Down Expand Up @@ -1819,11 +1827,20 @@ def get_vocabulary(
)
import_module(cfg.module)

if isinstance(cfg.mixture_or_task_name, seqio.DatasetProviderBase):
if isinstance(cfg.mixture_or_task_name, airio.DatasetProviderBase):
mixture_or_task = cfg.mixture_or_task_name
vocab_map = airio.get_vocabularies(mixture_or_task)
if not vocab_map:
raise ValueError(
f'No vocabularies found for AirIO task/mixture {mixture_or_task}'
)
features = {k: seqio.Feature(vocabulary=v) for k, v in vocab_map.items()}
elif isinstance(cfg.mixture_or_task_name, seqio.DatasetProviderBase):
mixture_or_task = cfg.mixture_or_task_name
features = mixture_or_task.output_features
else:
mixture_or_task = seqio.get_mixture_or_task(cfg.mixture_or_task_name)
features = mixture_or_task.output_features
features = mixture_or_task.output_features

if 'inputs' in features and 'targets' in features:
return (features['inputs'].vocabulary, features['targets'].vocabulary)
Expand Down Expand Up @@ -1900,7 +1917,7 @@ def get_dataset(
feature_converter_cls: Callable[..., seqio.FeatureConverter],
num_epochs: Optional[int] = None,
continue_from_last_checkpoint: bool = False,
) -> tf.data.Dataset:
) -> Union[tf.data.Dataset, clu.data.dataset_iterator.DatasetIterator]:
"""Returns a dataset from SeqIO based on a `DatasetConfig`."""
if continue_from_last_checkpoint:
raise ValueError(
Expand Down Expand Up @@ -1941,7 +1958,10 @@ def get_dataset_inner(
):
"""Internal fn to load a dataset from SeqIO based on a `DatasetConfig`."""
batch_size = cfg.batch_size // shard_info.num_shards
if isinstance(cfg.mixture_or_task_name, seqio.DatasetProviderBase):
if isinstance(
cfg.mixture_or_task_name,
(seqio.DatasetProviderBase, airio.DatasetProviderBase),
):
mixture_or_task = cfg.mixture_or_task_name
else:
mixture_or_task = seqio.get_mixture_or_task(cfg.mixture_or_task_name)
Expand Down Expand Up @@ -1993,7 +2013,7 @@ def __call__(
feature_converter_cls: Callable[..., seqio.FeatureConverter],
num_epochs: Optional[int] = None,
continue_from_last_checkpoint: bool = True,
) -> Union[clu.data.dataset_iterator.DatasetIterator, tf.data.Dataset]:
) -> Union[clu.data.dataset_iterator.DatasetIterator, tf.data.Dataset,]:
...


Expand All @@ -2007,7 +2027,9 @@ def __call__(
num_shards: int,
eval_steps: int,
feature_converter_cls: Callable[..., seqio.FeatureConverter],
) -> Mapping[str, tf.data.Dataset]:
) -> Mapping[
str, Union[tf.data.Dataset, airio.PyGrainDatasetIteratorWrapper]
]:
...


Expand All @@ -2020,9 +2042,19 @@ def get_training_eval_datasets(
deterministic: bool = False,
model_dir: Optional[str] = None,
start_step: int = 0,
) -> Mapping[str, tf.data.Dataset]:
) -> Mapping[
str,
Union[
tf.data.Dataset,
clu.data.dataset_iterator.DatasetIterator,
Iterable[Mapping[str, np.ndarray]],
],
]:
"""Returns a mapping from eval task name to its dataset."""
if isinstance(cfg.mixture_or_task_name, seqio.DatasetProviderBase):
if isinstance(
cfg.mixture_or_task_name,
(seqio.DatasetProviderBase, airio.DatasetProviderBase),
):
mixture_or_task = cfg.mixture_or_task_name
else:
mixture_or_task = seqio.get_mixture_or_task(cfg.mixture_or_task_name)
Expand All @@ -2034,6 +2066,22 @@ def get_training_eval_datasets(
get_deterministic_dataset, model_dir=model_dir, start_step=start_step
)

if isinstance(mixture_or_task, (airio.Task, airio.Mixture)):
data_iter = get_dataset_fn(
dataclasses.replace(cfg, batch_size=1),
shard_id=0,
num_shards=1,
feature_converter_cls=feature_converter_cls,
num_epochs=eval_steps * cfg.batch_size,
continue_from_last_checkpoint=False,
)
# TODO(b/304579895): Cannot use itertools.islice here to limit the number
# of records to eval_steps since peek() is required later. Instead, update
# t5x eval to not depend on the data pipeline and stop after eval_run
# steps.
datasets[mixture_or_task.name] = data_iter
return datasets

if cfg.batch_size % num_shards:
raise ValueError(
f'Batch size ({cfg.batch_size}) must be divisible by number of '
Expand Down Expand Up @@ -2125,7 +2173,7 @@ def __getitem__(self, key: str) -> Any:
def __len__(self) -> int:
return len(self._kvs)

def __iter__(self) -> Iterable[Tuple[re.Pattern, Any]]:
def __iter__(self) -> Iterable[Tuple[re.Pattern[str], Any]]:
return iter(self._kvs)


Expand Down

0 comments on commit b295e55

Please sign in to comment.