Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add return_tensor parameter for feature extraction #19257

Merged
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 11 additions & 3 deletions src/transformers/pipelines/feature_extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ class FeatureExtractionPipeline(Pipeline):
If no framework is specified, will default to the one currently installed. If no framework is specified and
both frameworks are installed, will default to the framework of the `model`, or to PyTorch if no model is
provided.
return_tensors (`bool`, *optional*):
If `True`, returns a tensor according to the specified framework. Otherwise, returns a list.
ajsanjoaquin marked this conversation as resolved.
Show resolved Hide resolved
task (`str`, defaults to `""`):
A task-identifier for the pipeline.
args_parser ([`~pipelines.ArgumentHandler`], *optional*):
Expand All @@ -40,7 +42,7 @@ class FeatureExtractionPipeline(Pipeline):
the associated CUDA device id.
"""

def _sanitize_parameters(self, truncation=None, tokenize_kwargs=None, **kwargs):
def _sanitize_parameters(self, truncation=None, tokenize_kwargs=None, return_tensors=None, **kwargs):
if tokenize_kwargs is None:
tokenize_kwargs = {}

Expand All @@ -53,7 +55,11 @@ def _sanitize_parameters(self, truncation=None, tokenize_kwargs=None, **kwargs):

preprocess_params = tokenize_kwargs

return preprocess_params, {}, {}
postprocess_params = {}
if return_tensors is not None:
postprocess_params["return_tensors"] = return_tensors

return preprocess_params, {}, postprocess_params

def preprocess(self, inputs, **tokenize_kwargs) -> Dict[str, GenericTensor]:
return_tensors = self.framework
Expand All @@ -64,8 +70,10 @@ def _forward(self, model_inputs):
model_outputs = self.model(**model_inputs)
return model_outputs

def postprocess(self, model_outputs):
def postprocess(self, model_outputs, return_tensors=False):
# [0] is the first available tensor, logits or last_hidden_state.
if return_tensors:
return model_outputs[0]
if self.framework == "pt":
return model_outputs[0].tolist()
elif self.framework == "tf":
Expand Down
18 changes: 18 additions & 0 deletions tests/pipelines/test_pipelines_feature_extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
import unittest

import numpy as np
import tensorflow as tf
import torch

from transformers import (
FEATURE_EXTRACTOR_MAPPING,
Expand Down Expand Up @@ -133,6 +135,22 @@ def test_tokenization_small_model_tf(self):
tokenize_kwargs=tokenize_kwargs,
)

@require_torch
def test_return_tensors_pt(self):
    """With return_tensors=True the pipeline should hand back a raw torch tensor rather than nested lists."""
    extractor = pipeline(
        task="feature-extraction", model="hf-internal-testing/tiny-random-distilbert", framework="pt"
    )
    features = extractor("This is a test" * 100, return_tensors=True)
    self.assertTrue(torch.is_tensor(features))

@require_tf
def test_return_tensors_tf(self):
    """With return_tensors=True the pipeline should hand back a raw TensorFlow tensor rather than nested lists."""
    extractor = pipeline(
        task="feature-extraction", model="hf-internal-testing/tiny-random-distilbert", framework="tf"
    )
    features = extractor("This is a test" * 100, return_tensors=True)
    self.assertTrue(tf.is_tensor(features))

def get_shape(self, input_, shape=None):
if shape is None:
shape = []
Expand Down