
Adding support for fp16 for asr pipeline. #20864

Merged — 10 commits merged into huggingface:main from support_fp16_asr on Dec 23, 2022

Conversation

@Narsil (Contributor) commented Dec 21, 2022

What does this PR do?

Fixes #20862

Many things were considered before settling for this design.

  • feature_extractor(return_tensors="pt", torch_dtype=torch_dtype). This would have the advantage of being consistent, but not all feature extractors define this, so it would affect all of them. Then why would we use torch_dtype instead of the more commonplace dtype, which could apply to TF and Flax as well? It also feels a bit redundant to specify both return_tensors and torch_dtype; they would be good candidates to fuse into a single parameter (but that's outside the scope of this PR).
  • AutoFeatureExtractor.from_pretrained(..., torch_dtype=torch_dtype). This would have the advantage of being set once globally, so users don't need to respecify it on each call. However, we can't specify return_tensors="pt" there either, so for consistency I didn't try to put it there.
  • ffmpeg_read(..., dtype=dtype). This would be nice: load the waveform directly in fp16 and just let fp16 flow through the feature_extractor. However, Whisper in particular uses a mel spectrogram, so using fp16 audio might actually damage performance.

In the end, this solution is the simplest I could come up with: let torch_dtype flow to the pipeline, use it as a regular parameter, and convert the output of the feature_extractor afterwards.

This does incur a potential extra copy, but there's no risk of damaging the quality of the input.
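For illustration, a minimal sketch of the resulting usage (the model name, audio file, and device index are placeholders; a CUDA device is assumed, since fp16 compute effectively requires a GPU):

import torch
from transformers import pipeline

asr = pipeline(
    "automatic-speech-recognition",
    model="openai/whisper-large-v2",  # placeholder checkpoint
    torch_dtype=torch.float16,        # flows through to the feature extractor output as well
    device=0,                         # assumes a CUDA GPU is available
)
print(asr("sample.flac"))             # placeholder audio file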


Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@HuggingFaceDocBuilderDev commented Dec 21, 2022

The documentation is not available anymore as the PR was closed or merged.

        yield item
else:
    processed = self.feature_extractor(
        inputs, sampling_rate=self.feature_extractor.sampling_rate, return_tensors="pt"
    )
    if dtype is not None:
        processed = {k: v.to(dtype=dtype) for k, v in processed.items()}
@bofenghuang (Contributor) commented Dec 22, 2022

Hi @Narsil,

I think this works fine for Whisper models because they only have a single input, input_features.

But for other models like wav2vec2, the feature extractor returns multiple values of different dtypes: input_values, which needs to be cast from float32 to float16, and attention_mask, which I'm not sure whether to keep as int32 or cast to int16.

Collaborator

Yes. And as above, if you directly use the to method on processed, it will take care of that for you.
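For context, a minimal sketch of that behavior with dummy tensors (it assumes BatchFeature.to from #20536, which only casts floating-point tensors and leaves integer ones such as attention_mask untouched):

import torch
from transformers import BatchFeature

# Dummy wav2vec2-style feature extractor output: float features plus an integer attention mask.
processed = BatchFeature({
    "input_values": torch.zeros(1, 16000, dtype=torch.float32),
    "attention_mask": torch.ones(1, 16000, dtype=torch.int32),
})
processed = processed.to(dtype=torch.float16)
print(processed["input_values"].dtype)    # torch.float16
print(processed["attention_mask"].dtype)  # torch.int32 — integer tensors are not converted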

@Narsil (PR author)

Done. Thanks, TIL

@sgugger (Collaborator) left a comment

Thanks for working on this! My only comment is to make sure to leverage the to method on BatchFeature (if the feature extractor here returns another type, maybe make sure its to method handles dtype arguments) so that checks like not converting int inputs are applied for free.

Otherwise LGTM!

inputs_len = inputs.shape[0]
step = chunk_len - stride_left - stride_right
for i in range(0, inputs_len, step):
    # add start and end paddings to the chunk
    chunk = inputs[i : i + chunk_len]
    processed = feature_extractor(chunk, sampling_rate=feature_extractor.sampling_rate, return_tensors="pt")
    if dtype is not None:
        processed = {k: v.to(dtype=dtype) for k, v in processed.items()}
Collaborator

I believe you can call to directly on processed, which is a BatchFeature and handles dtype in its to method thanks to #20536 (it was designed for vision, but I think it will apply here too).
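In other words, a sketch of the suggested simplification (assuming processed is a BatchFeature of torch tensors):

# before: manually cast every tensor in the dict
if dtype is not None:
    processed = {k: v.to(dtype=dtype) for k, v in processed.items()}

# after: let BatchFeature.to handle the dtype check (integer tensors are left as-is)
if dtype is not None:
    processed = processed.to(dtype=dtype)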

@Narsil (PR author)

Done !

@@ -249,7 +253,8 @@ def _sanitize_parameters(self, **kwargs):

         return preprocess_params, {}, postprocess_params

-    def preprocess(self, inputs, chunk_length_s=0, stride_length_s=None, ignore_warning=False):
+    def preprocess(self, inputs, chunk_length_s=0, stride_length_s=None, ignore_warning=False, dtype=None):
+        print(f"Running with dtype {dtype}")
Collaborator

To be cleaned up ;-)

@Narsil (PR author)

Oops

@Narsil merged commit f7f0ec2 into huggingface:main on Dec 23, 2022
@Narsil deleted the support_fp16_asr branch on December 23, 2022
MKhalusova pushed a commit to MKhalusova/transformers that referenced this pull request Dec 28, 2022
* Supporting `fp16` for asr pipeline

* Adding test.

* Style.

* Oops.

* Flake8 update ?

* Fixing flake8 ?

* Revert "Flake8 update ?"

This reverts commit 0b917fc.

* Style (acctidentally deleted flake8 F401.)

* Move to a bigger test (no small whisper model, and s2t doesn't seem to
accept torch_dtype=fp16).

Also we need to use a GPU to actually compute on fp16.

* Using BatchFeature capability.
amyeroberts pushed a commit to amyeroberts/transformers that referenced this pull request Jan 4, 2023
silverriver pushed a commit to silverriver/transformers that referenced this pull request Jan 6, 2023
venkat-natchi pushed a commit to venkat-natchi/transformers that referenced this pull request Jan 22, 2023
miyu386 pushed a commit to miyu386/transformers that referenced this pull request Feb 9, 2023

Successfully merging this pull request may close this issue: Run AutomaticSpeechRecognitionPipeline with FP16 (#20862)