From 65bad7e4df14e251e591a95db5e988613320d80d Mon Sep 17 00:00:00 2001
From: Patrick Haller
Date: Tue, 14 Feb 2023 15:26:41 +0100
Subject: [PATCH] Unify `load_from_cache_file` type and logic (#5515)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* Updating type annotations for `load_from_cache_file`

* Added logic for cache checking if needed

* Updated documentation following the wording of `Dataset.map`

Co-authored-by: Mario Šaško
---
 src/datasets/arrow_dataset.py | 28 +++++++++++++++++-----------
 src/datasets/dataset_dict.py  | 16 ++++++++--------
 2 files changed, 25 insertions(+), 19 deletions(-)

diff --git a/src/datasets/arrow_dataset.py b/src/datasets/arrow_dataset.py
index 6d017dd6765f..0ee90d5b94ed 100644
--- a/src/datasets/arrow_dataset.py
+++ b/src/datasets/arrow_dataset.py
@@ -1837,7 +1837,7 @@ def cast(
         features: Features,
         batch_size: Optional[int] = 1000,
         keep_in_memory: bool = False,
-        load_from_cache_file: bool = True,
+        load_from_cache_file: Optional[bool] = None,
         cache_file_name: Optional[str] = None,
         writer_batch_size: Optional[int] = 1000,
         num_proc: Optional[int] = None,
@@ -2709,7 +2709,7 @@ def map(
         drop_last_batch: bool = False,
         remove_columns: Optional[Union[str, List[str]]] = None,
         keep_in_memory: bool = False,
-        load_from_cache_file: bool = None,
+        load_from_cache_file: Optional[bool] = None,
         cache_file_name: Optional[str] = None,
         writer_batch_size: Optional[int] = 1000,
         features: Optional[Features] = None,
@@ -2768,7 +2768,7 @@ def map(
                 columns with names in `remove_columns`, these columns will be kept.
             keep_in_memory (`bool`, defaults to `False`):
                 Keep the dataset in memory instead of writing it to a cache file.
-            load_from_cache_file (`bool`, defaults to `True` if caching is enabled):
+            load_from_cache_file (`Optional[bool]`, defaults to `True` if caching is enabled):
                 If a cache file storing the current computation from `function`
                 can be identified, use it instead of recomputing.
             cache_file_name (`str`, *optional*, defaults to `None`):
@@ -3365,7 +3365,7 @@ def filter(
         batched: bool = False,
         batch_size: Optional[int] = 1000,
         keep_in_memory: bool = False,
-        load_from_cache_file: bool = True,
+        load_from_cache_file: Optional[bool] = None,
         cache_file_name: Optional[str] = None,
         writer_batch_size: Optional[int] = 1000,
         fn_kwargs: Optional[dict] = None,
@@ -3399,7 +3399,7 @@ def filter(
                 If `batch_size <= 0` or `batch_size == None`, provide the full dataset as a single batch to `function`.
             keep_in_memory (`bool`, defaults to `False`):
                 Keep the dataset in memory instead of writing it to a cache file.
-            load_from_cache_file (`bool`, defaults to `True`):
+            load_from_cache_file (`Optional[bool]`, defaults to `True` if caching is enabled):
                 If a cache file storing the current computation from `function`
                 can be identified, use it instead of recomputing.
             cache_file_name (`str`, *optional*):
@@ -3807,7 +3807,7 @@ def sort(
         kind: str = None,
         null_placement: str = "last",
         keep_in_memory: bool = False,
-        load_from_cache_file: bool = True,
+        load_from_cache_file: Optional[bool] = None,
         indices_cache_file_name: Optional[str] = None,
         writer_batch_size: Optional[int] = 1000,
         new_fingerprint: Optional[str] = None,
@@ -3833,7 +3833,7 @@ def sort(
             keep_in_memory (`bool`, defaults to `False`):
                 Keep the sorted indices in memory instead of writing it to a cache file.
-            load_from_cache_file (`bool`, defaults to `True`):
+            load_from_cache_file (`Optional[bool]`, defaults to `True` if caching is enabled):
                 If a cache file storing the sorted indices
                 can be identified, use it instead of recomputing.
             indices_cache_file_name (`str`, *optional*, defaults to `None`):
@@ -3872,6 +3872,8 @@ def sort(
                     f"Column '{column}' not found in the dataset. Please provide a column selected in: {self._data.column_names}"
                 )
 
+        load_from_cache_file = load_from_cache_file if load_from_cache_file is not None else is_caching_enabled()
+
         # Check if we've already cached this computation (indexed by a hash)
         if self.cache_files:
             if indices_cache_file_name is None:
@@ -3909,7 +3911,7 @@ def shuffle(
         seed: Optional[int] = None,
         generator: Optional[np.random.Generator] = None,
         keep_in_memory: bool = False,
-        load_from_cache_file: bool = True,
+        load_from_cache_file: Optional[bool] = None,
         indices_cache_file_name: Optional[str] = None,
         writer_batch_size: Optional[int] = 1000,
         new_fingerprint: Optional[str] = None,
@@ -3957,7 +3959,7 @@ def shuffle(
                 If `generator=None` (default), uses `np.random.default_rng` (the default BitGenerator (PCG64) of NumPy).
             keep_in_memory (`bool`, default `False`):
                 Keep the shuffled indices in memory instead of writing it to a cache file.
-            load_from_cache_file (`bool`, defaults to `True`):
+            load_from_cache_file (`Optional[bool]`, defaults to `True` if caching is enabled):
                 If a cache file storing the shuffled indices
                 can be identified, use it instead of recomputing.
             indices_cache_file_name (`str`, *optional*):
@@ -4002,6 +4004,8 @@ def shuffle(
         if generator is not None and not isinstance(generator, np.random.Generator):
             raise ValueError("The provided generator must be an instance of numpy.random.Generator")
 
+        load_from_cache_file = load_from_cache_file if load_from_cache_file is not None else is_caching_enabled()
+
         if generator is None:
             if seed is None:
                 _, seed, pos, *_ = np.random.get_state()
@@ -4046,7 +4050,7 @@ def train_test_split(
         seed: Optional[int] = None,
         generator: Optional[np.random.Generator] = None,
         keep_in_memory: bool = False,
-        load_from_cache_file: bool = True,
+        load_from_cache_file: Optional[bool] = None,
         train_indices_cache_file_name: Optional[str] = None,
         test_indices_cache_file_name: Optional[str] = None,
         writer_batch_size: Optional[int] = 1000,
@@ -4083,7 +4087,7 @@ def train_test_split(
                 If `generator=None` (default), uses `np.random.default_rng` (the default BitGenerator (PCG64) of NumPy).
             keep_in_memory (`bool`, defaults to `False`):
                 Keep the splits indices in memory instead of writing it to a cache file.
-            load_from_cache_file (`bool`, defaults to `True`):
+            load_from_cache_file (`Optional[bool]`, defaults to `True` if caching is enabled):
                 If a cache file storing the splits indices
                 can be identified, use it instead of recomputing.
             train_cache_file_name (`str`, *optional*):
@@ -4223,6 +4227,8 @@ def train_test_split(
                 "aforementioned parameters."
             )
 
+        load_from_cache_file = load_from_cache_file if load_from_cache_file is not None else is_caching_enabled()
+
         if generator is None and shuffle is True:
             if seed is None:
                 _, seed, pos, *_ = np.random.get_state()
diff --git a/src/datasets/dataset_dict.py b/src/datasets/dataset_dict.py
index ee8baaa106a1..03eede3b320b 100644
--- a/src/datasets/dataset_dict.py
+++ b/src/datasets/dataset_dict.py
@@ -756,7 +756,7 @@ def map(
         drop_last_batch: bool = False,
         remove_columns: Optional[Union[str, List[str]]] = None,
         keep_in_memory: bool = False,
-        load_from_cache_file: bool = True,
+        load_from_cache_file: Optional[bool] = None,
         cache_file_names: Optional[Dict[str, Optional[str]]] = None,
         writer_batch_size: Optional[int] = 1000,
         features: Optional[Features] = None,
@@ -801,7 +801,7 @@ def map(
                 columns with names in `remove_columns`, these columns will be kept.
             keep_in_memory (`bool`, defaults to `False`):
                 Keep the dataset in memory instead of writing it to a cache file.
-            load_from_cache_file (`bool`, defaults to `True`):
+            load_from_cache_file (`Optional[bool]`, defaults to `True` if caching is enabled):
                 If a cache file storing the current computation from `function`
                 can be identified, use it instead of recomputing.
             cache_file_names (`[Dict[str, str]]`, *optional*, defaults to `None`):
@@ -881,7 +881,7 @@ def filter(
         batched: bool = False,
         batch_size: Optional[int] = 1000,
         keep_in_memory: bool = False,
-        load_from_cache_file: bool = True,
+        load_from_cache_file: Optional[bool] = None,
         cache_file_names: Optional[Dict[str, Optional[str]]] = None,
         writer_batch_size: Optional[int] = 1000,
         fn_kwargs: Optional[dict] = None,
@@ -911,7 +911,7 @@ def filter(
                 `batch_size <= 0` or `batch_size == None` then provide the full dataset as a single batch to `function`.
             keep_in_memory (`bool`, defaults to `False`):
                 Keep the dataset in memory instead of writing it to a cache file.
-            load_from_cache_file (`bool`, defaults to `True`):
+            load_from_cache_file (`Optional[bool]`, defaults to `True` if caching is enabled):
                 If a cache file storing the current computation from `function`
                 can be identified, use it instead of recomputing.
             cache_file_names (`[Dict[str, str]]`, *optional*, defaults to `None`):
@@ -982,7 +982,7 @@ def sort(
         kind: str = None,
         null_placement: str = "last",
         keep_in_memory: bool = False,
-        load_from_cache_file: bool = True,
+        load_from_cache_file: Optional[bool] = None,
         indices_cache_file_names: Optional[Dict[str, Optional[str]]] = None,
         writer_batch_size: Optional[int] = 1000,
     ) -> "DatasetDict":
@@ -1008,7 +1008,7 @@ def sort(
             keep_in_memory (`bool`, defaults to `False`):
                 Keep the sorted indices in memory instead of writing it to a cache file.
-            load_from_cache_file (`bool`, defaults to `True`):
+            load_from_cache_file (`Optional[bool]`, defaults to `True` if caching is enabled):
                 If a cache file storing the sorted indices
                 can be identified, use it instead of recomputing.
             indices_cache_file_names (`[Dict[str, str]]`, *optional*, defaults to `None`):
@@ -1056,7 +1056,7 @@ def shuffle(
         seed: Optional[int] = None,
         generators: Optional[Dict[str, np.random.Generator]] = None,
         keep_in_memory: bool = False,
-        load_from_cache_file: bool = True,
+        load_from_cache_file: Optional[bool] = None,
         indices_cache_file_names: Optional[Dict[str, Optional[str]]] = None,
         writer_batch_size: Optional[int] = 1000,
     ) -> "DatasetDict":
@@ -1081,7 +1081,7 @@ def shuffle(
                 You have to provide one `generator` per dataset in the dataset dictionary.
             keep_in_memory (`bool`, defaults to `False`):
                 Keep the dataset in memory instead of writing it to a cache file.
-            load_from_cache_file (`bool`, defaults to `True`):
+            load_from_cache_file (`Optional[bool]`, defaults to `True` if caching is enabled):
                 If a cache file storing the current computation from `function`
                 can be identified, use it instead of recomputing.
             indices_cache_file_names (`Dict[str, str]`, *optional*):
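
For reviewers who want to exercise the unified behavior in isolation: the sketch below is illustrative and not part of the patch. `is_caching_enabled`, `disable_caching`, and `enable_caching` are real top-level helpers in `datasets`; the wrapper name `resolve_load_from_cache_file` is hypothetical, standing in for the one-liner this patch inserts into `sort`, `shuffle`, and `train_test_split` (and that `Dataset.map` already applies).

    from typing import Optional

    from datasets import disable_caching, enable_caching, is_caching_enabled


    def resolve_load_from_cache_file(load_from_cache_file: Optional[bool]) -> bool:
        # Same resolution rule as the inserted line: an explicit True/False
        # always wins; None defers to the global caching switch.
        return load_from_cache_file if load_from_cache_file is not None else is_caching_enabled()


    assert resolve_load_from_cache_file(True) is True    # explicit opt-in
    assert resolve_load_from_cache_file(False) is False  # explicit opt-out
    print(resolve_load_from_cache_file(None))   # True: caching is enabled by default

    disable_caching()                           # flip the global switch
    print(resolve_load_from_cache_file(None))   # now False
    enable_caching()                            # restore the default

The practical effect: with caching disabled globally, methods that previously hard-coded `load_from_cache_file: bool = True` now recompute instead of silently reusing a cache file, while an explicit `load_from_cache_file=True` still opts back in per call.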