Preserve stopping_strategy of shuffled interleaved dataset (random cycling case) #5816

Merged 2 commits on May 4, 2023
13 changes: 4 additions & 9 deletions src/datasets/arrow_dataset.py
@@ -112,17 +112,12 @@
from .utils.hub import hf_hub_url
from .utils.info_utils import is_small_dataset
from .utils.metadata import DatasetMetadata
- from .utils.py_utils import asdict, convert_file_size_to_int, iflatmap_unordered, unique_values
+ from .utils.py_utils import Literal, asdict, convert_file_size_to_int, iflatmap_unordered, unique_values
from .utils.stratify import stratified_shuffle_split_generate_indices
from .utils.tf_utils import dataset_to_tf, minimal_tf_collate_fn, multiprocess_dataset_to_tf
from .utils.typing import PathLike


- try:
- from typing import Literal
- except ImportError:
- from typing_extensions import Literal
-
if TYPE_CHECKING:
import sqlite3

@@ -3085,7 +3080,7 @@ def load_processed_shard_from_cache(shard_kwargs):
else:

def format_cache_file_name(
- cache_file_name: Optional[str], rank: Union[int, Literal["*"]]
+ cache_file_name: Optional[str], rank: Union[int, Literal["*"]] # noqa: F722
) -> Optional[str]:
if not cache_file_name:
return cache_file_name
@@ -5980,7 +5975,7 @@ def _interleave_map_style_datasets(
seed: Optional[int] = None,
info: Optional[DatasetInfo] = None,
split: Optional[NamedSplit] = None,
- stopping_strategy: Optional[str] = "first_exhausted",
+ stopping_strategy: Literal["first_exhausted", "all_exhausted"] = "first_exhausted",
**kwargs,
) -> "Dataset":
"""
@@ -5996,7 +5991,7 @@ def _interleave_map_style_datasets(
seed (`int`, optional, default None): The random seed used to choose a source for each example.
info (:class:`DatasetInfo`, optional): Dataset information, like description, citation, etc.
split (:class:`NamedSplit`, optional): Name of the dataset split.
- stopping_strategy (Optional `str`, defaults to `first_exhausted`):
+ stopping_strategy (`str`, defaults to `first_exhausted`):
Two strategies are proposed right now.
By default, `first_exhausted` is an undersampling strategy, i.e the dataset construction is stopped as soon as one dataset has ran out of samples.
If the strategy is `all_exhausted`, we use an oversampling strategy, i.e the dataset construction is stopped as soon as every samples of every dataset has been added at least once.
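A minimal usage sketch of the two strategies (not part of this diff; the column name `a` and the toy sizes are made up):

```python
from datasets import Dataset, interleave_datasets

d1 = Dataset.from_dict({"a": [0, 1, 2]})
d2 = Dataset.from_dict({"a": [10, 11, 12, 13, 14]})

# Default undersampling: stop as soon as the shortest source runs out of samples.
under = interleave_datasets([d1, d2], stopping_strategy="first_exhausted")

# Oversampling: keep cycling (restarting exhausted sources) until every sample
# of every source has been yielded at least once.
over = interleave_datasets([d1, d2], stopping_strategy="all_exhausted")

print(len(under), len(over))  # the oversampled result is the longer of the two
```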
5 changes: 3 additions & 2 deletions src/datasets/combine.py
@@ -5,6 +5,7 @@
from .iterable_dataset import IterableDataset, _concatenate_iterable_datasets, _interleave_iterable_datasets
from .splits import NamedSplit
from .utils import logging
+ from .utils.py_utils import Literal


logger = logging.get_logger(__name__)
@@ -19,7 +20,7 @@ def interleave_datasets(
seed: Optional[int] = None,
info: Optional[DatasetInfo] = None,
split: Optional[NamedSplit] = None,
- stopping_strategy: Optional[str] = "first_exhausted",
+ stopping_strategy: Literal["first_exhausted", "all_exhausted"] = "first_exhausted",
) -> DatasetType:
"""
Interleave several datasets (sources) into a single dataset.
@@ -52,7 +53,7 @@ def interleave_datasets(
split ([`NamedSplit`], *optional*):
Name of the dataset split.
<Added version="2.4.0"/>
- stopping_strategy (`str`, *optional*, defaults to `first_exhausted`):
+ stopping_strategy (`str`, defaults to `first_exhausted`):
Two strategies are proposed right now, `first_exhausted` and `all_exhausted`.
By default, `first_exhausted` is an undersampling strategy, i.e the dataset construction is stopped as soon as one dataset has ran out of samples.
If the strategy is `all_exhausted`, we use an oversampling strategy, i.e the dataset construction is stopped as soon as every samples of every dataset has been added at least once.
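A minimal sketch of the random-cycling mode this PR targets (not part of this diff; toy data, arbitrary probabilities): passing `probabilities` draws a source at random for each example, and `seed` makes the draws reproducible.

```python
from datasets import Dataset, interleave_datasets

d1 = Dataset.from_dict({"text": ["a"] * 10})
d2 = Dataset.from_dict({"text": ["b"] * 10})

# With `probabilities`, a source is drawn at random for each example instead of
# strict round-robin cycling.
mixed = interleave_datasets(
    [d1, d2],
    probabilities=[0.8, 0.2],
    seed=42,
    stopping_strategy="all_exhausted",
)
print(mixed["text"][:10])  # typically mostly "a" with the occasional "b"
```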
30 changes: 18 additions & 12 deletions src/datasets/iterable_dataset.py
@@ -20,6 +20,7 @@
from .splits import NamedSplit
from .table import table_cast
from .utils.logging import get_logger
+ from .utils.py_utils import Literal
from .utils.sharding import _merge_gen_kwargs, _number_of_shards_in_gen_kwargs, _shuffle_gen_kwargs, _split_gen_kwargs


@@ -52,7 +53,7 @@ def _batch_to_examples(batch: Dict[str, list]) -> List[Dict[str, Any]]:
yield {col: array[i] for col, array in batch.items()}


- class HasNextIterator(Iterator):
+ class _HasNextIterator(Iterator):
"""Iterator with an hasnext() function. Taken from https://stackoverflow.com/questions/1966591/has-next-in-python-iterators."""

def __init__(self, it):
@@ -202,7 +203,9 @@ def n_shards(self) -> int:

class CyclingMultiSourcesExamplesIterable(_BaseExamplesIterable):
def __init__(
- self, ex_iterables: List[_BaseExamplesIterable], stopping_strategy: Optional[str] = "first_exhausted"
+ self,
+ ex_iterables: List[_BaseExamplesIterable],
+ stopping_strategy: Literal["first_exhausted", "all_exhausted"] = "first_exhausted",
):
self.ex_iterables = ex_iterables
self.stopping_strategy = stopping_strategy
@@ -211,14 +214,14 @@ def __init__(
# if oversampling ("all_exhausted"), we stop as soons as every dataset is exhausted, i.e as soon as every samples of every dataset has been visited at least once
self.bool_strategy_func = np.all if (stopping_strategy == "all_exhausted") else np.any

- def _give_indice_iterator(self):
+ def _get_indices_iterator(self):
# this is an infinite iterator to keep track of which iterator we want to pick examples from
return cycle(range(len(self.ex_iterables)))

def __iter__(self):
- iterators = [HasNextIterator(ex_iterable) for ex_iterable in self.ex_iterables]
+ iterators = [_HasNextIterator(ex_iterable) for ex_iterable in self.ex_iterables]

- indices_iterator = self._give_indice_iterator()
+ indices_iterator = self._get_indices_iterator()

is_exhausted = np.full(len(self.ex_iterables), False)
for i in indices_iterator:
@@ -233,7 +236,7 @@ def __iter__(self):
# if the stopping criteria is met, break the main for loop
break
# otherwise reinitialise the iterator and yield the first example
- iterators[i] = HasNextIterator(self.ex_iterables[i])
+ iterators[i] = _HasNextIterator(self.ex_iterables[i])

except StopIteration:
# here it means that the i-th iterabledataset is empty, i.e we never have the occasion to yield an element of the i-th dataset.
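A simplified stand-alone sketch of the cycling logic above (an approximation, not the library code: it detects exhaustion on `StopIteration` rather than via the has-next wrapper, and assumes non-empty sources):

```python
from itertools import cycle

def round_robin(sources, stopping_strategy="first_exhausted"):
    # `sources` are plain iterables standing in for the wrapped examples iterables.
    iterators = [iter(s) for s in sources]
    exhausted = [False] * len(sources)
    # first_exhausted -> stop when ANY source is exhausted; all_exhausted -> when ALL are.
    stop = all if stopping_strategy == "all_exhausted" else any
    for i in cycle(range(len(sources))):
        try:
            yield next(iterators[i])
        except StopIteration:
            exhausted[i] = True
            if stop(exhausted):
                return
            # oversampling: restart the exhausted source and keep cycling
            iterators[i] = iter(sources[i])
            yield next(iterators[i])

print(list(round_robin([[0, 1], [10, 11, 12]])))                   # undersampling
print(list(round_robin([[0, 1], [10, 11, 12]], "all_exhausted")))  # oversampling
```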
@@ -381,7 +384,7 @@ def __init__(
ex_iterables,
generator: np.random.Generator,
probabilities: Optional[List[float]] = None,
- stopping_strategy: Optional[str] = "first_exhausted",
+ stopping_strategy: Literal["first_exhausted", "all_exhausted"] = "first_exhausted",
):
super().__init__(ex_iterables, stopping_strategy)
self.generator = deepcopy(generator)
@@ -402,7 +405,7 @@ def _iter_random_indices(
while True:
yield from (int(i) for i in rng.choice(num_sources, size=random_batch_size, p=p))

- def _give_indice_iterator(self):
+ def _get_indices_iterator(self):
rng = deepcopy(self.generator)
# this is an infinite iterator that randomly samples the index of the source to pick examples from
return self._iter_random_indices(rng, len(self.ex_iterables), p=self.probabilities)
@@ -411,7 +414,7 @@ def shuffle_data_sources(self, generator: np.random.Generator) -> "RandomlyCyclingMultiSourcesExamplesIterable":
"""Shuffle the data sources of each wrapped examples iterable."""
ex_iterables = [ex_iterable.shuffle_data_sources(generator) for ex_iterable in self.ex_iterables]
return RandomlyCyclingMultiSourcesExamplesIterable(
- ex_iterables, generator=generator, probabilities=self.probabilities
+ ex_iterables,
+ generator=generator,
+ probabilities=self.probabilities,
+ stopping_strategy=self.stopping_strategy,
)

def shard_data_sources(self, worker_id: int, num_workers: int) -> "RandomlyCyclingMultiSourcesExamplesIterable":
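Passing `stopping_strategy=self.stopping_strategy` above is the actual fix: previously `shuffle_data_sources` rebuilt the iterable with the default `first_exhausted`. A hedged end-to-end sketch (toy datasets, hypothetical column name `a`):

```python
from datasets import Dataset, interleave_datasets

ds1 = Dataset.from_dict({"a": list(range(10))}).to_iterable_dataset()
ds2 = Dataset.from_dict({"a": list(range(100, 103))}).to_iterable_dataset()

mixed = interleave_datasets(
    [ds1, ds2],
    probabilities=[0.5, 0.5],
    seed=42,
    stopping_strategy="all_exhausted",
)

# Shuffling goes through shuffle_data_sources() on the underlying
# RandomlyCyclingMultiSourcesExamplesIterable; with this change the shuffled
# dataset keeps the "all_exhausted" oversampling behaviour instead of silently
# falling back to "first_exhausted".
shuffled = mixed.shuffle(seed=42, buffer_size=100)
print(sum(1 for _ in shuffled))
```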
@@ -1824,7 +1830,7 @@ def _interleave_iterable_datasets(
seed: Optional[int] = None,
info: Optional[DatasetInfo] = None,
split: Optional[NamedSplit] = None,
- stopping_strategy: Optional[str] = "first_exhausted",
+ stopping_strategy: Literal["first_exhausted", "all_exhausted"] = "first_exhausted",
) -> IterableDataset:
"""
Interleave several iterable datasets (sources) into a single iterable dataset.
@@ -1839,7 +1845,7 @@ def _interleave_iterable_datasets(
probabilities (`List[float]`, optional, default None): If specified, the new iterable dataset samples
examples from one source at a time according to these probabilities.
seed (`int`, optional, default None): The random seed used to choose a source for each example.
- stopping_strategy (Optional `str`, defaults to `first_exhausted`):
+ stopping_strategy (`str`, defaults to `first_exhausted`):
Two strategies are proposed right now.
By default, `first_exhausted` is an undersampling strategy, i.e the dataset construction is stopped as soon as one dataset has ran out of samples.
If the strategy is `all_exhausted`, we use an oversampling strategy, i.e the dataset construction is stopped as soon as every samples of every dataset has been added at least once.
@@ -1863,7 +1869,7 @@ def _interleave_iterable_datasets(

ex_iterables = [d._ex_iterable for d in datasets]

- # Use cycling or random cycling or sources
+ # Use cycling or random cycling of sources
if probabilities is None:
ex_iterable = CyclingMultiSourcesExamplesIterable(ex_iterables, stopping_strategy=stopping_strategy)
else:
2 changes: 1 addition & 1 deletion src/datasets/packaged_modules/csv/csv.py
@@ -4,12 +4,12 @@

import pandas as pd
import pyarrow as pa
- from typing_extensions import Literal

import datasets
import datasets.config
from datasets.features.features import require_storage_cast
from datasets.table import table_cast
+ from datasets.utils.py_utils import Literal


logger = datasets.utils.logging.get_logger(__name__)