Skip to content

Commit

Permalink
Fix pipeline tests (#18487)
Browse files Browse the repository at this point in the history
* Fix pipeline tests

* Make sure all pipelines tests run with init changes
  • Loading branch information
sgugger authored Aug 5, 2022
1 parent c7849d9 commit 70fa1a8
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 2 deletions.
1 change: 1 addition & 0 deletions src/transformers/pipelines/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -606,6 +606,7 @@ def pipeline(

# Retrieve the task
if task in custom_tasks:
normalized_task = task
targeted_task, task_options = clean_custom_task(custom_tasks[task])
if pipeline_class is None:
if not trust_remote_code:
Expand Down
4 changes: 2 additions & 2 deletions tests/pipelines/test_pipelines_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -795,7 +795,7 @@ def test_warning_logs(self):
alias = "text-classification"
# Get the original task, so we can restore it at the end.
# (otherwise the subsequential tests in `TextClassificationPipelineTests` will fail)
original_task, original_task_options = PIPELINE_REGISTRY.check_task(alias)
_, original_task, _ = PIPELINE_REGISTRY.check_task(alias)

try:
with CaptureLogger(logger_) as cm:
Expand All @@ -816,7 +816,7 @@ def test_register_pipeline(self):
)
assert "custom-text-classification" in PIPELINE_REGISTRY.get_supported_tasks()

task_def, _ = PIPELINE_REGISTRY.check_task("custom-text-classification")
_, task_def, _ = PIPELINE_REGISTRY.check_task("custom-text-classification")
self.assertEqual(task_def["pt"], (AutoModelForSequenceClassification,) if is_torch_available() else ())
self.assertEqual(task_def["tf"], (TFAutoModelForSequenceClassification,) if is_tf_available() else ())
self.assertEqual(task_def["type"], "text")
Expand Down
1 change: 1 addition & 0 deletions utils/tests_fetcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,6 +377,7 @@ def create_reverse_dependency_map():
],
"optimization.py": "optimization/test_optimization.py",
"optimization_tf.py": "optimization/test_optimization_tf.py",
"pipelines/__init__.py": "pipelines/test_pipelines_*.py",
"pipelines/base.py": "pipelines/test_pipelines_*.py",
"pipelines/text2text_generation.py": [
"pipelines/test_pipelines_text2text_generation.py",
Expand Down

0 comments on commit 70fa1a8

Please sign in to comment.