Adding support for fp16 for asr pipeline. #20864

Merged (10 commits) on Dec 23, 2022
3 changes: 3 additions & 0 deletions src/transformers/pipelines/__init__.py
@@ -874,6 +874,9 @@ def pipeline(
if feature_extractor is not None:
kwargs["feature_extractor"] = feature_extractor

if torch_dtype is not None:
kwargs["torch_dtype"] = torch_dtype

if device is not None:
kwargs["device"] = device

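With torch_dtype now forwarded alongside feature_extractor and device, half precision can be requested when the pipeline is constructed. A minimal usage sketch of what this change enables (the checkpoint is the one exercised in this PR's test; the audio path is a hypothetical placeholder):

    import torch
    from transformers import pipeline

    asr = pipeline(
        "automatic-speech-recognition",
        model="openai/whisper-base",
        device=0,                   # fp16 inference generally needs a CUDA device
        torch_dtype=torch.float16,  # forwarded via kwargs["torch_dtype"] above
    )
    print(asr("sample.flac")["text"])  # "sample.flac" is a placeholder file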
13 changes: 10 additions & 3 deletions src/transformers/pipelines/automatic_speech_recognition.py
@@ -52,13 +52,15 @@ def rescale_stride(stride, ratio):
return new_strides


def chunk_iter(inputs, feature_extractor, chunk_len, stride_left, stride_right):
def chunk_iter(inputs, feature_extractor, chunk_len, stride_left, stride_right, dtype=None):
inputs_len = inputs.shape[0]
step = chunk_len - stride_left - stride_right
for i in range(0, inputs_len, step):
# add start and end paddings to the chunk
chunk = inputs[i : i + chunk_len]
processed = feature_extractor(chunk, sampling_rate=feature_extractor.sampling_rate, return_tensors="pt")
if dtype is not None:
processed = processed.to(dtype=dtype)
_stride_left = 0 if i == 0 else stride_left
is_last = i + step + stride_left >= inputs_len
_stride_right = 0 if is_last else stride_right
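Each iteration advances by step = chunk_len - stride_left - stride_right, so consecutive windows overlap by the stride amounts, and the new dtype argument casts each chunk's features (e.g. to torch.float16) immediately after extraction. A toy trace of the windowing arithmetic, with made-up sizes instead of real audio:

    # Same arithmetic as chunk_iter above, on toy numbers.
    inputs_len = 10
    chunk_len, stride_left, stride_right = 4, 1, 1
    step = chunk_len - stride_left - stride_right  # 2 fresh samples per chunk

    for i in range(0, inputs_len, step):
        is_last = i + step + stride_left >= inputs_len
        window = (i, min(i + chunk_len, inputs_len))
        strides = (0 if i == 0 else stride_left, 0 if is_last else stride_right)
        print(window, strides)
    # output, one pair per line:
    # (0, 4) (0, 1) / (2, 6) (1, 1) / (4, 8) (1, 1) / (6, 10) (1, 1) / (8, 10) (1, 0)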
@@ -240,6 +242,8 @@ def _sanitize_parameters(self, **kwargs):
preprocess_params["stride_length_s"] = kwargs["stride_length_s"]
if "ignore_warning" in kwargs:
preprocess_params["ignore_warning"] = kwargs["ignore_warning"]
if "torch_dtype" in kwargs:
preprocess_params["dtype"] = kwargs["torch_dtype"]

postprocess_params = {}
if "decoder_kwargs" in kwargs:
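The user-facing torch_dtype kwarg is stored under the key dtype, which is how it reaches the dtype parameter of preprocess() below. A stand-alone sketch of that routing (not the real pipeline class, just the mapping it performs):

    import torch

    def _sanitize_parameters(**kwargs):
        preprocess_params = {}
        if "torch_dtype" in kwargs:
            # pipeline-level name -> the `dtype` argument of preprocess()
            preprocess_params["dtype"] = kwargs["torch_dtype"]
        return preprocess_params, {}, {}

    preprocess_params, _, _ = _sanitize_parameters(torch_dtype=torch.float16)
    assert preprocess_params == {"dtype": torch.float16}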
@@ -249,7 +253,7 @@ def _sanitize_parameters(self, **kwargs):

return preprocess_params, {}, postprocess_params

def preprocess(self, inputs, chunk_length_s=0, stride_length_s=None, ignore_warning=False):
def preprocess(self, inputs, chunk_length_s=0, stride_length_s=None, ignore_warning=False, dtype=None):
if isinstance(inputs, str):
if inputs.startswith("http://") or inputs.startswith("https://"):
# We need to actually check for a real protocol, otherwise it's impossible to use a local file
@@ -332,12 +336,14 @@ def preprocess(self, inputs, chunk_length_s=0, stride_length_s=None, ignore_warn
raise ValueError("Chunk length must be superior to stride length")

# make sure that each overlapping chunk is yielded with its stride info
for item in chunk_iter(inputs, self.feature_extractor, chunk_len, stride_left, stride_right):
for item in chunk_iter(inputs, self.feature_extractor, chunk_len, stride_left, stride_right, dtype):
yield item
else:
processed = self.feature_extractor(
inputs, sampling_rate=self.feature_extractor.sampling_rate, return_tensors="pt"
)
if dtype is not None:
processed = processed.to(dtype=dtype)
if stride is not None:
if self.model.__class__ in MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING.values():
raise ValueError("Stride is only usable with CTC models, try removing it")
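In the non-chunked path the whole waveform is featurized at once, and the resulting BatchFeature is cast with the same .to(dtype=...) call used in chunk_iter. A stand-alone check of that cast, assuming the Whisper checkpoint from this PR's test:

    import numpy as np
    import torch
    from transformers import AutoFeatureExtractor

    fe = AutoFeatureExtractor.from_pretrained("openai/whisper-base")
    waveform = np.zeros(16_000, dtype=np.float32)  # one second of silence at 16 kHz
    processed = fe(waveform, sampling_rate=fe.sampling_rate, return_tensors="pt")
    processed = processed.to(dtype=torch.float16)  # the cast this PR adds
    print(processed["input_features"].dtype)       # torch.float16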
@@ -366,6 +372,7 @@ def _forward(self, model_inputs):
# `generate` magic to create the mask automatically won't work, we basically need to help
# it here.
attention_mask = model_inputs.pop("attention_mask", None)

tokens = self.model.generate(
encoder_outputs=encoder(inputs, attention_mask=attention_mask),
attention_mask=attention_mask,
13 changes: 13 additions & 0 deletions tests/pipelines/test_pipelines_automatic_speech_recognition.py
@@ -145,6 +145,19 @@ def test_small_model_pt(self):
with self.assertRaisesRegex(ValueError, "^We cannot return_timestamps yet on non-ctc models !$"):
_ = speech_recognizer(waveform, return_timestamps="char")

@slow
@require_torch
def test_whisper_fp16(self):
if not torch.cuda.is_available():
self.skipTest("Cuda is necessary for this test")
speech_recognizer = pipeline(
model="openai/whisper-base",
device=0,
torch_dtype=torch.float16,
)
waveform = np.tile(np.arange(1000, dtype=np.float32), 34)
speech_recognizer(waveform)

@require_torch
def test_small_model_pt_seq2seq(self):
speech_recognizer = pipeline(
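The new test is a smoke test: the waveform is a synthetic ramp (np.tile(np.arange(1000, dtype=np.float32), 34) gives 34,000 samples, roughly two seconds at 16 kHz), and the only assertion is that fp16 inference completes on GPU without a dtype mismatch; the transcription itself is not checked. Per the usual transformers test conventions, the @slow marker means it runs only when slow tests are enabled, on a CUDA machine, e.g.:

    RUN_SLOW=1 pytest tests/pipelines/test_pipelines_automatic_speech_recognition.py -k whisper_fp16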