diff --git a/seqio/dataset_providers.py b/seqio/dataset_providers.py
index f50154b2..a731f05e 100644
--- a/seqio/dataset_providers.py
+++ b/seqio/dataset_providers.py
@@ -130,7 +130,7 @@ def splits(self) -> Sequence[str]:
   def get_dataset(
       self,
       sequence_length: Optional[Mapping[str, int]] = None,
-      split: str = tfds.Split.TRAIN,
+      split: str = tfds.Split.TRAIN,  # pytype: disable=attribute-error
       use_cached: bool = False,
       shuffle: bool = True,
       seed: Optional[int] = None,
@@ -173,7 +173,7 @@ def add_provider(cls, name: str, provider):
     task_registry_provenance_tracking.maybe_record_provenance(
         frame=inspect.currentframe(),
         name=name,
-        provider_type=provider.__class__.__name__,
+        provider_type=provider.__class__.__name__,  # pytype: disable=attribute-error
     )
 
   @classmethod
@@ -325,7 +325,7 @@ def list_shards(self, split: str) -> Sequence[str]:
   @abc.abstractmethod
   def get_dataset(
       self,  # pytype: disable=signature-mismatch  # overriding-default-value-checks
-      split: str = tfds.Split.TRAIN,
+      split: str = tfds.Split.TRAIN,  # pytype: disable=attribute-error
       shuffle: bool = True,
       seed: Optional[int] = None,
       shard_info: Optional[ShardInfo] = None,
@@ -432,7 +432,7 @@ def __repr__(self):
 
   def get_dataset(
       self,
-      split: str = tfds.Split.TRAIN,
+      split: str = tfds.Split.TRAIN,  # pytype: disable=attribute-error
       shuffle: bool = True,
       seed: Optional[int] = None,
       shard_info: Optional[ShardInfo] = None,
@@ -550,7 +550,7 @@ def get_dataset(
       num_epochs: Optional[int] = 1,  # Unused
   ) -> tf.data.Dataset:
     if split is None:
-      split = tfds.Split.TRAIN
+      split = tfds.Split.TRAIN  # pytype: disable=attribute-error
     return self.tfds_dataset.load(
         split, shuffle_files=shuffle, seed=seed, shard_info=shard_info
     )
@@ -639,7 +639,7 @@ def __repr__(self):
 
   def get_dataset(
       self,
-      split: str = tfds.Split.TRAIN,
+      split: str = tfds.Split.TRAIN,  # pytype: disable=attribute-error
       shuffle: bool = True,
       seed: Optional[int] = None,
       shard_info: Optional[ShardInfo] = None,
@@ -694,7 +694,7 @@ def list_shards(self, split: str) -> Sequence[str]:
       return _list_files(pattern=filepattern)
 
     if not any(glob.has_magic(f) for f in filepattern):
-      return filepattern
+      return filepattern  # pytype: disable=bad-return-type
     else:
       return _list_files(pattern=filepattern)
 
@@ -1512,7 +1512,7 @@ def assert_cached(self) -> None:
     ), f"'{self.name}' does not exist in any of the task cache directories."
 
   def get_cached_stats(
-      self, split: str = tfds.Split.TRAIN
+      self, split: str = tfds.Split.TRAIN  # pytype: disable=attribute-error
   ) -> Mapping[str, Union[int, float]]:
     """Returns basic statistics for cached dataset."""
     self.assert_cached()
@@ -1526,10 +1526,10 @@ def get_cached_stats(
         self._stats[split] = json.load(f)
     return self._stats[split]
 
-  def get_dataset(
+  def get_dataset(  # pylint: disable=arguments-renamed
       self,  # pytype: disable=signature-mismatch  # overriding-default-value-checks
       sequence_length: Optional[Mapping[str, int]] = None,
-      split: str = tfds.Split.TRAIN,
+      split: str = tfds.Split.TRAIN,  # pytype: disable=attribute-error
       use_cached: bool = False,
       shuffle: bool = True,
       shuffle_buffer_size: Optional[int] = None,  # Unique to Task
@@ -1614,7 +1614,7 @@ def get_dataset(
       )
     else:
       ds = source.get_dataset(split=split, shuffle=shuffle, seed=seed)
-      ds = ds.shard(shard_info.num_shards, shard_info.index)
+      ds = ds.shard(shard_info.num_shards, shard_info.index)  # pytype: disable=attribute-error
     num_shards = shard_info.num_shards if shard_info else 1
 
     if try_in_mem_cache and (
@@ -1915,7 +1915,7 @@ def get_task_dataset(
       task: Task,
       output_feature_keys: Set[str],
       sequence_length: Optional[Mapping[str, int]] = None,
-      split: str = tfds.Split.TRAIN,
+      split: str = tfds.Split.TRAIN,  # pytype: disable=attribute-error
       use_cached: bool = False,
       shuffle: bool = True,
       seed: Optional[int] = None,
@@ -1947,7 +1947,7 @@ def _get_all_mixing_rates(self, tasks):
   def get_dataset(  # pytype: disable=signature-mismatch  # overriding-parameter-type-checks
       self,
       sequence_length: Optional[Mapping[str, int]] = None,
-      split: str = tfds.Split.TRAIN,
+      split: str = tfds.Split.TRAIN,  # pytype: disable=attribute-error
       use_cached: bool = False,
       shuffle: bool = True,
       seed: Optional[int] = None,
@@ -2115,6 +2115,98 @@ def _get_submixture_rate(self, mix: "Mixture") -> float:
     return float(rate)
 
 
+def get_dataset_iterator_from_tasks(
+    tasks: Union[
+        Sequence[SubtaskOrName], Sequence[Tuple[SubtaskOrName, MixtureRate]]
+    ],
+    sources: Sequence[grain.TfDataSource],
+    proportions: Sequence[float],
+    shard_info: Optional[ShardInfo],
+    seed: Optional[int],
+    num_epochs: Optional[int],
+    strict_transformations: bool,
+    shuffle: bool,
+    batch_size: Optional[int],
+    sequence_length: Optional[Mapping[str, int]],
+    trim_output_features: bool,
+    output_features: Mapping[str, str],
+    feature_converter: FeatureConverter,
+) -> grain.TfGrainDatasetIterator:
+  """Returns a deterministic DatasetIterator for the mixture."""
+  if shard_info is None:
+    shard_options = grain.NoSharding()
+  else:
+    shard_options = grain.ShardOptions(
+        shard_index=shard_info.index, shard_count=shard_info.num_shards
+    )
+
+  if num_epochs and num_epochs != 1:
+    raise ValueError(
+        "Epochs are not supported for mixtures. A mixture "
+        "always repeats indefinitely over its tasks."
+    )
+
+  if sequence_length is not None:
+    # Avoid the index being dropped. In case of example packing we even need
+    # to pack it (but it should never be the limiting factor).
+    sequence_length = dict(sequence_length)
+    sequence_length[grain.INDEX] = max(sequence_length.values())
+
+  extra_args = {
+      "sequence_length": sequence_length,
+      "output_features": output_features,
+  }
+  add_kwargs = lambda t: utils.add_kwargs_to_transform(t, **extra_args)
+
+  transformations_per_source = []
+  for task in tasks:
+    transformations_per_source.append(
+        [add_kwargs(t) for t in task.preprocessors]  # pytype: disable=attribute-error
+    )  # pylint: disable=protected-access
+  # Transformations applied after combining all data sources.
+  transformations = [
+      seqio_preprocessors.ReshapeFeatures({grain.INDEX: [-1]}),
+      seqio_preprocessors.DropFeatures(
+          set(grain.META_FEATURES) - {grain.INDEX}
+      ),
+  ]
+  if trim_output_features:
+    transformations.append(seqio_preprocessors._TrimDataset())  # pylint: disable=protected-access
+  if hasattr(feature_converter, "get_grain_transforms"):
+    transformations += feature_converter.get_grain_transforms(
+        batch_size=batch_size, task_feature_lengths=sequence_length
+    )
+  elif strict_transformations:
+    raise NotImplementedError(
+        f"FeatureConverter {feature_converter} does "
+        "not implement get_grain_transforms()."
+    )
+  else:
+    transformations += [
+        functools.partial(
+            feature_converter, task_feature_lengths=sequence_length
+        )
+    ]
+  transformations = [add_kwargs(t) for t in transformations]
+
+  sampler = grain.TfMixtureIndexSampler(
+      [len(s) for s in sources],
+      shard_options=shard_options,
+      proportions=proportions,
+      shuffle=shuffle,
+      seed=seed,
+  )
+  data_loader = grain.TfMixtureDataLoader(
+      sources=sources,
+      sampler=sampler,
+      transformations_per_source=transformations_per_source,
+      transformations=transformations,
+      iterator_options=grain.IteratorOptions(drop_grain_meta_features=True),
+      strict_transformations=strict_transformations,
+  )
+  return iter(data_loader)  # pytype: disable=bad-return-type
+
+
 def _log_padding_fractions(dataset, sequence_length, num_examples=100):
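# --- Illustrative usage (not part of the patch) ---
# A minimal, hypothetical sketch of how the new get_dataset_iterator_from_tasks
# helper might be called. It assumes the mixture's tasks are backed by
# grain.TfDataSource sources; the mixture name, sequence lengths, batch size,
# and feature converter choice below are made-up assumptions for illustration.
import seqio
from seqio import dataset_providers

mixture = seqio.get_mixture_or_task("my_mixture")  # hypothetical mixture name
tasks = list(mixture.tasks)
iterator = dataset_providers.get_dataset_iterator_from_tasks(
    tasks=tasks,
    sources=[t.source for t in tasks],  # assumed to be grain.TfDataSource
    proportions=[1.0] * len(tasks),
    shard_info=seqio.ShardInfo(index=0, num_shards=1),
    seed=42,
    num_epochs=1,  # mixtures repeat indefinitely; values other than 1 raise
    strict_transformations=False,
    shuffle=True,
    batch_size=16,
    sequence_length={"inputs": 512, "targets": 128},
    trim_output_features=True,
    output_features=mixture.output_features,
    feature_converter=seqio.EncDecFeatureConverter(pack=True),
)
first_batch = next(iterator)  # deterministic grain iterator over the mixture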