diff --git a/src/transformers/pipelines/__init__.py b/src/transformers/pipelines/__init__.py index 7a022e5635e674..d2a4b663801d78 100755 --- a/src/transformers/pipelines/__init__.py +++ b/src/transformers/pipelines/__init__.py @@ -606,6 +606,7 @@ def pipeline( # Retrieve the task if task in custom_tasks: + normalized_task = task targeted_task, task_options = clean_custom_task(custom_tasks[task]) if pipeline_class is None: if not trust_remote_code: diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py index 83474a5ba048e2..5d5c8fa2333eb6 100644 --- a/tests/pipelines/test_pipelines_common.py +++ b/tests/pipelines/test_pipelines_common.py @@ -795,7 +795,7 @@ def test_warning_logs(self): alias = "text-classification" # Get the original task, so we can restore it at the end. # (otherwise the subsequential tests in `TextClassificationPipelineTests` will fail) - original_task, original_task_options = PIPELINE_REGISTRY.check_task(alias) + _, original_task, _ = PIPELINE_REGISTRY.check_task(alias) try: with CaptureLogger(logger_) as cm: @@ -816,7 +816,7 @@ def test_register_pipeline(self): ) assert "custom-text-classification" in PIPELINE_REGISTRY.get_supported_tasks() - task_def, _ = PIPELINE_REGISTRY.check_task("custom-text-classification") + _, task_def, _ = PIPELINE_REGISTRY.check_task("custom-text-classification") self.assertEqual(task_def["pt"], (AutoModelForSequenceClassification,) if is_torch_available() else ()) self.assertEqual(task_def["tf"], (TFAutoModelForSequenceClassification,) if is_tf_available() else ()) self.assertEqual(task_def["type"], "text") diff --git a/utils/tests_fetcher.py b/utils/tests_fetcher.py index 9f18bb83c7ee7f..329d248de3c089 100644 --- a/utils/tests_fetcher.py +++ b/utils/tests_fetcher.py @@ -377,6 +377,7 @@ def create_reverse_dependency_map(): ], "optimization.py": "optimization/test_optimization.py", "optimization_tf.py": "optimization/test_optimization_tf.py", + "pipelines/__init__.py": "pipelines/test_pipelines_*.py", "pipelines/base.py": "pipelines/test_pipelines_*.py", "pipelines/text2text_generation.py": [ "pipelines/test_pipelines_text2text_generation.py",