Unify load_from_cache_file type and logic (huggingface#5515)
* Updating type annotations for `load_from_cache_file`
* Added logic for cache checking if needed
* Updated documentation following the wording of `Dataset.map`

Co-authored-by: Mario Šaško <mariosasko777@gmail.com>
2 people authored and Filip Haltmayer committed Feb 16, 2023
1 parent c5b7f06 commit 65bad7e
Showing 2 changed files with 25 additions and 19 deletions.
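The change in one place: every affected method now takes `load_from_cache_file: Optional[bool] = None` and resolves `None` against the global caching switch. A minimal standalone sketch of that resolution (not the library source; `resolve_load_from_cache_file` is a hypothetical helper, while `is_caching_enabled` is the real top-level `datasets` function):

    from typing import Optional

    from datasets import is_caching_enabled

    def resolve_load_from_cache_file(load_from_cache_file: Optional[bool]) -> bool:
        # None means "follow the global caching switch" toggled by
        # datasets.enable_caching() / datasets.disable_caching();
        # an explicit True or False always wins.
        if load_from_cache_file is not None:
            return load_from_cache_file
        return is_caching_enabled()

`Dataset.map` already defaulted to `None`; the commit fixes its annotation and brings `filter`, `sort`, `shuffle`, `train_test_split`, and the `DatasetDict` counterparts to the same default, adding the explicit fallback where a method did not already resolve `None`.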
28 changes: 17 additions & 11 deletions src/datasets/arrow_dataset.py
@@ -1837,7 +1837,7 @@ def cast(
features: Features,
batch_size: Optional[int] = 1000,
keep_in_memory: bool = False,
- load_from_cache_file: bool = True,
+ load_from_cache_file: Optional[bool] = None,
cache_file_name: Optional[str] = None,
writer_batch_size: Optional[int] = 1000,
num_proc: Optional[int] = None,
@@ -2709,7 +2709,7 @@ def map(
drop_last_batch: bool = False,
remove_columns: Optional[Union[str, List[str]]] = None,
keep_in_memory: bool = False,
- load_from_cache_file: bool = None,
+ load_from_cache_file: Optional[bool] = None,
cache_file_name: Optional[str] = None,
writer_batch_size: Optional[int] = 1000,
features: Optional[Features] = None,
@@ -2768,7 +2768,7 @@ def map(
columns with names in `remove_columns`, these columns will be kept.
keep_in_memory (`bool`, defaults to `False`):
Keep the dataset in memory instead of writing it to a cache file.
- load_from_cache_file (`bool`, defaults to `True` if caching is enabled):
+ load_from_cache_file (`Optional[bool]`, defaults to `True` if caching is enabled):
If a cache file storing the current computation from `function`
can be identified, use it instead of recomputing.
cache_file_name (`str`, *optional*, defaults to `None`):
@@ -3365,7 +3365,7 @@ def filter(
batched: bool = False,
batch_size: Optional[int] = 1000,
keep_in_memory: bool = False,
- load_from_cache_file: bool = True,
+ load_from_cache_file: Optional[bool] = None,
cache_file_name: Optional[str] = None,
writer_batch_size: Optional[int] = 1000,
fn_kwargs: Optional[dict] = None,
@@ -3399,7 +3399,7 @@ def filter(
If `batch_size <= 0` or `batch_size == None`, provide the full dataset as a single batch to `function`.
keep_in_memory (`bool`, defaults to `False`):
Keep the dataset in memory instead of writing it to a cache file.
- load_from_cache_file (`bool`, defaults to `True`):
+ load_from_cache_file (`Optional[bool]`, defaults to `True` if caching is enabled):
If a cache file storing the current computation from `function`
can be identified, use it instead of recomputing.
cache_file_name (`str`, *optional*):
@@ -3807,7 +3807,7 @@ def sort(
kind: str = None,
null_placement: str = "last",
keep_in_memory: bool = False,
- load_from_cache_file: bool = True,
+ load_from_cache_file: Optional[bool] = None,
indices_cache_file_name: Optional[str] = None,
writer_batch_size: Optional[int] = 1000,
new_fingerprint: Optional[str] = None,
@@ -3833,7 +3833,7 @@ def sort(
<Added version="1.14.2"/>
keep_in_memory (`bool`, defaults to `False`):
Keep the sorted indices in memory instead of writing it to a cache file.
- load_from_cache_file (`bool`, defaults to `True`):
+ load_from_cache_file (`Optional[bool]`, defaults to `True` if caching is enabled):
If a cache file storing the sorted indices
can be identified, use it instead of recomputing.
indices_cache_file_name (`str`, *optional*, defaults to `None`):
@@ -3872,6 +3872,8 @@ def sort(
f"Column '{column}' not found in the dataset. Please provide a column selected in: {self._data.column_names}"
)

+ load_from_cache_file = load_from_cache_file if load_from_cache_file is not None else is_caching_enabled()

# Check if we've already cached this computation (indexed by a hash)
if self.cache_files:
if indices_cache_file_name is None:
@@ -3909,7 +3911,7 @@ def shuffle(
seed: Optional[int] = None,
generator: Optional[np.random.Generator] = None,
keep_in_memory: bool = False,
- load_from_cache_file: bool = True,
+ load_from_cache_file: Optional[bool] = None,
indices_cache_file_name: Optional[str] = None,
writer_batch_size: Optional[int] = 1000,
new_fingerprint: Optional[str] = None,
@@ -3957,7 +3959,7 @@ def shuffle(
If `generator=None` (default), uses `np.random.default_rng` (the default BitGenerator (PCG64) of NumPy).
keep_in_memory (`bool`, default `False`):
Keep the shuffled indices in memory instead of writing it to a cache file.
- load_from_cache_file (`bool`, defaults to `True`):
+ load_from_cache_file (`Optional[bool]`, defaults to `True` if caching is enabled):
If a cache file storing the shuffled indices
can be identified, use it instead of recomputing.
indices_cache_file_name (`str`, *optional*):
@@ -4002,6 +4004,8 @@ def shuffle(
if generator is not None and not isinstance(generator, np.random.Generator):
raise ValueError("The provided generator must be an instance of numpy.random.Generator")

+ load_from_cache_file = load_from_cache_file if load_from_cache_file is not None else is_caching_enabled()

if generator is None:
if seed is None:
_, seed, pos, *_ = np.random.get_state()
@@ -4046,7 +4050,7 @@ def train_test_split(
seed: Optional[int] = None,
generator: Optional[np.random.Generator] = None,
keep_in_memory: bool = False,
- load_from_cache_file: bool = True,
+ load_from_cache_file: Optional[bool] = None,
train_indices_cache_file_name: Optional[str] = None,
test_indices_cache_file_name: Optional[str] = None,
writer_batch_size: Optional[int] = 1000,
@@ -4083,7 +4087,7 @@ def train_test_split(
If `generator=None` (default), uses `np.random.default_rng` (the default BitGenerator (PCG64) of NumPy).
keep_in_memory (`bool`, defaults to `False`):
Keep the splits indices in memory instead of writing it to a cache file.
- load_from_cache_file (`bool`, defaults to `True`):
+ load_from_cache_file (`Optional[bool]`, defaults to `True` if caching is enabled):
If a cache file storing the splits indices
can be identified, use it instead of recomputing.
train_cache_file_name (`str`, *optional*):
@@ -4223,6 +4227,8 @@ def train_test_split(
"aforementioned parameters."
)

+ load_from_cache_file = load_from_cache_file if load_from_cache_file is not None else is_caching_enabled()

if generator is None and shuffle is True:
if seed is None:
_, seed, pos, *_ = np.random.get_state()
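Caller-visible effect of the arrow_dataset.py changes above, as a hypothetical usage sketch (the toy dataset and calls are illustrative, not part of the commit; for a small in-memory dataset there may be no cache file to reuse in the first place):

    from datasets import Dataset, disable_caching, enable_caching

    ds = Dataset.from_dict({"x": list(range(10))})

    # Default (load_from_cache_file=None): whether previously written
    # cache files are reused now follows the global switch.
    disable_caching()
    shuffled = ds.shuffle(seed=42)  # does not reuse cached shuffle indices

    # An explicit boolean still overrides the global setting.
    enable_caching()
    recomputed = ds.shuffle(seed=42, load_from_cache_file=False)  # forces recompute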
16 changes: 8 additions & 8 deletions src/datasets/dataset_dict.py
@@ -756,7 +756,7 @@ def map(
drop_last_batch: bool = False,
remove_columns: Optional[Union[str, List[str]]] = None,
keep_in_memory: bool = False,
- load_from_cache_file: bool = True,
+ load_from_cache_file: Optional[bool] = None,
cache_file_names: Optional[Dict[str, Optional[str]]] = None,
writer_batch_size: Optional[int] = 1000,
features: Optional[Features] = None,
@@ -801,7 +801,7 @@ def map(
columns with names in `remove_columns`, these columns will be kept.
keep_in_memory (`bool`, defaults to `False`):
Keep the dataset in memory instead of writing it to a cache file.
- load_from_cache_file (`bool`, defaults to `True`):
+ load_from_cache_file (`Optional[bool]`, defaults to `True` if caching is enabled):
If a cache file storing the current computation from `function`
can be identified, use it instead of recomputing.
cache_file_names (`[Dict[str, str]]`, *optional*, defaults to `None`):
@@ -881,7 +881,7 @@ def filter(
batched: bool = False,
batch_size: Optional[int] = 1000,
keep_in_memory: bool = False,
- load_from_cache_file: bool = True,
+ load_from_cache_file: Optional[bool] = None,
cache_file_names: Optional[Dict[str, Optional[str]]] = None,
writer_batch_size: Optional[int] = 1000,
fn_kwargs: Optional[dict] = None,
@@ -911,7 +911,7 @@ def filter(
`batch_size <= 0` or `batch_size == None` then provide the full dataset as a single batch to `function`.
keep_in_memory (`bool`, defaults to `False`):
Keep the dataset in memory instead of writing it to a cache file.
- load_from_cache_file (`bool`, defaults to `True`):
+ load_from_cache_file (`Optional[bool]`, defaults to `True` if caching is enabled):
If a cache file storing the current computation from `function`
can be identified, use it instead of recomputing.
cache_file_names (`[Dict[str, str]]`, *optional*, defaults to `None`):
@@ -982,7 +982,7 @@ def sort(
kind: str = None,
null_placement: str = "last",
keep_in_memory: bool = False,
- load_from_cache_file: bool = True,
+ load_from_cache_file: Optional[bool] = None,
indices_cache_file_names: Optional[Dict[str, Optional[str]]] = None,
writer_batch_size: Optional[int] = 1000,
) -> "DatasetDict":
@@ -1008,7 +1008,7 @@ def sort(
<Added version="1.14.2"/>
keep_in_memory (`bool`, defaults to `False`):
Keep the sorted indices in memory instead of writing it to a cache file.
- load_from_cache_file (`bool`, defaults to `True`):
+ load_from_cache_file (`Optional[bool]`, defaults to `True` if caching is enabled):
If a cache file storing the sorted indices
can be identified, use it instead of recomputing.
indices_cache_file_names (`[Dict[str, str]]`, *optional*, defaults to `None`):
@@ -1056,7 +1056,7 @@ def shuffle(
seed: Optional[int] = None,
generators: Optional[Dict[str, np.random.Generator]] = None,
keep_in_memory: bool = False,
- load_from_cache_file: bool = True,
+ load_from_cache_file: Optional[bool] = None,
indices_cache_file_names: Optional[Dict[str, Optional[str]]] = None,
writer_batch_size: Optional[int] = 1000,
) -> "DatasetDict":
@@ -1081,7 +1081,7 @@ def shuffle(
You have to provide one `generator` per dataset in the dataset dictionary.
keep_in_memory (`bool`, defaults to `False`):
Keep the dataset in memory instead of writing it to a cache file.
- load_from_cache_file (`bool`, defaults to `True`):
+ load_from_cache_file (`Optional[bool]`, defaults to `True` if caching is enabled):
If a cache file storing the current computation from `function`
can be identified, use it instead of recomputing.
indices_cache_file_names (`Dict[str, str]`, *optional*):
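The dataset_dict.py side only needs the signature and docstring updates: `DatasetDict` methods accept the same `Optional[bool]` and forward it to each split, so `None` is resolved inside the per-split `Dataset` call. A simplified sketch of that forwarding (not the actual method body):

    from typing import Callable, Dict, Optional

    from datasets import Dataset

    def map_each_split(
        splits: Dict[str, Dataset],
        function: Callable,
        load_from_cache_file: Optional[bool] = None,
    ) -> Dict[str, Dataset]:
        # The unresolved Optional[bool] is passed straight through;
        # the None -> is_caching_enabled() fallback happens inside Dataset.map.
        return {
            name: ds.map(function, load_from_cache_file=load_from_cache_file)
            for name, ds in splits.items()
        }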
