From b65660b7c6e853391991734210e38f805459b0ed Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Mario=20=C5=A0a=C5=A1ko?=
Date: Mon, 10 Jul 2023 14:24:01 +0200
Subject: [PATCH] Deprecate task api (#5865)

* Deprecate Task API

* Typo

* Update task_templates.mdx

* Update task_templates.mdx
---
 .../package_reference/task_templates.mdx      |  6 ++++++
 src/datasets/arrow_dataset.py                 |  2 ++
 src/datasets/dataset_dict.py                  |  2 ++
 src/datasets/load.py                          | 21 ++++++++++++++++---
 4 files changed, 28 insertions(+), 3 deletions(-)

diff --git a/docs/source/package_reference/task_templates.mdx b/docs/source/package_reference/task_templates.mdx
index 52d275b8531..d07566590e3 100644
--- a/docs/source/package_reference/task_templates.mdx
+++ b/docs/source/package_reference/task_templates.mdx
@@ -1,5 +1,11 @@
 # Task templates
 
+<Deprecated version="2.13.0">
+
+The Task API is deprecated in favor of [`train-eval-index`](https://github.com/huggingface/hub-docs/blob/9ab2555e1c146122056aba6f89af404a8bc9a6f1/datasetcard.md?plain=1#L90-L106) and will be removed in the next major release.
+
+</Deprecated>
+
 The tasks supported by [`Dataset.prepare_for_task`] and [`DatasetDict.prepare_for_task`].
 
 [[autodoc]] datasets.tasks.AutomaticSpeechRecognition
diff --git a/src/datasets/arrow_dataset.py b/src/datasets/arrow_dataset.py
index 4ce83f584c6..06bbc178916 100644
--- a/src/datasets/arrow_dataset.py
+++ b/src/datasets/arrow_dataset.py
@@ -109,6 +109,7 @@
 )
 from .tasks import TaskTemplate
 from .utils import logging
+from .utils.deprecation_utils import deprecated
 from .utils.file_utils import _retry, cached_path, estimate_dataset_size
 from .utils.hub import hf_hub_url
 from .utils.info_utils import is_small_dataset
@@ -2706,6 +2707,7 @@ def with_transform(
         dataset.set_transform(transform=transform, columns=columns, output_all_columns=output_all_columns)
         return dataset
 
+    @deprecated()
     def prepare_for_task(self, task: Union[str, TaskTemplate], id: int = 0) -> "Dataset":
         """
         Prepare a dataset for the given task by casting the dataset's [`Features`] to standardized column names and types as detailed in [`datasets.tasks`](./task_templates).
diff --git a/src/datasets/dataset_dict.py b/src/datasets/dataset_dict.py
index 86e2ee99f5f..6c2b6211f53 100644
--- a/src/datasets/dataset_dict.py
+++ b/src/datasets/dataset_dict.py
@@ -25,6 +25,7 @@
 from .table import Table
 from .tasks import TaskTemplate
 from .utils import logging
+from .utils.deprecation_utils import deprecated
 from .utils.doc_utils import is_documented_by
 from .utils.file_utils import cached_path
 from .utils.hub import hf_hub_url
@@ -1537,6 +1538,7 @@ def from_text(
             path_or_paths, features=features, cache_dir=cache_dir, keep_in_memory=keep_in_memory, **kwargs
         ).read()
 
+    @deprecated()
     @is_documented_by(Dataset.prepare_for_task)
     def prepare_for_task(self, task: Union[str, TaskTemplate], id: int = 0) -> "DatasetDict":
         self._check_values_type()
diff --git a/src/datasets/load.py b/src/datasets/load.py
index 05979a6438f..88f9b05fe13 100644
--- a/src/datasets/load.py
+++ b/src/datasets/load.py
@@ -63,7 +63,6 @@
     _hash_python_lines,
 )
 from .splits import Split
-from .tasks import TaskTemplate
 from .utils.deprecation_utils import deprecated
 from .utils.file_utils import (
     OfflineModeIsEnabled,
@@ -1586,7 +1585,7 @@ def load_dataset(
     revision: Optional[Union[str, Version]] = None,
     token: Optional[Union[bool, str]] = None,
     use_auth_token="deprecated",
-    task: Optional[Union[str, TaskTemplate]] = None,
+    task="deprecated",
     streaming: bool = False,
     num_proc: Optional[int] = None,
     storage_options: Optional[Dict] = None,
@@ -1708,6 +1707,12 @@ def load_dataset(
         task (`str`):
             The task to prepare the dataset for during training and evaluation. Casts the dataset's [`Features`] to
             standardized column names and types as detailed in `datasets.tasks`.
+
+            <Deprecated version="2.13.0">
+
+            `task` was deprecated in version 2.13.0 and will be removed in 3.0.0.
+
+            </Deprecated>
         streaming (`bool`, defaults to `False`):
             If set to `True`, don't download the data files. Instead, it streams the data progressively while
             iterating on the dataset. An [`IterableDataset`] or [`IterableDatasetDict`] is returned instead in this case.
@@ -1795,6 +1800,13 @@ def load_dataset(
             f"You can remove this warning by passing 'verification_mode={verification_mode.value}' instead.",
             FutureWarning,
         )
+    if task != "deprecated":
+        warnings.warn(
+            "'task' was deprecated in version 2.13.0 and will be removed in 3.0.0.\n",
+            FutureWarning,
+        )
+    else:
+        task = None
     if data_files is not None and not data_files:
         raise ValueError(f"Empty 'data_files': '{data_files}'. It should be either non-empty or None (default).")
     if Path(path, config.DATASET_STATE_JSON_FILENAME).exists():
@@ -1855,7 +1867,10 @@ def load_dataset(
     ds = builder_instance.as_dataset(split=split, verification_mode=verification_mode, in_memory=keep_in_memory)
     # Rename and cast features to match task schema
     if task is not None:
-        ds = ds.prepare_for_task(task)
+        # To avoid issuing the same warning twice
+        with warnings.catch_warnings():
+            warnings.simplefilter("ignore", FutureWarning)
+            ds = ds.prepare_for_task(task)
     if save_infos:
         builder_instance._save_infos()