Adding support for fp16 for asr pipeline #20864
Conversation
                yield item
        else:
            processed = self.feature_extractor(
                inputs, sampling_rate=self.feature_extractor.sampling_rate, return_tensors="pt"
            )
            if dtype is not None:
                processed = {k: v.to(dtype=dtype) for k, v in processed.items()}
Hi @Narsil,

I think this works fine for whisper models because they only have a single value, `input_features`.

But in the case of other models like wav2vec2, the feature extractor returns multiple values of different dtypes: `input_values`, which needs to be cast from `float32` to `float16`, and `attention_mask`, which I'm not sure whether to keep as `int32` or cast to `int16`.
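To make the concern concrete, here is a rough sketch of what a blanket per-tensor cast does to a wav2vec2-style output (the checkpoint and dummy input are purely illustrative):

```python
import torch
from transformers import AutoFeatureExtractor

# Illustrative checkpoint and input, only to show the dtypes involved.
feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-base-960h")
dummy_audio = torch.zeros(16000).numpy()  # 1 second of silence at 16 kHz

processed = feature_extractor(
    dummy_audio,
    sampling_rate=feature_extractor.sampling_rate,
    return_attention_mask=True,
    return_tensors="pt",
)
# "input_values" comes back as float32, "attention_mask" as an integer tensor.
naive = {k: v.to(dtype=torch.float16) for k, v in processed.items()}
print({k: v.dtype for k, v in naive.items()})  # the attention mask gets cast to float16 too
```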
Yes. And as above, if you directly use the `to` method on `processed`, it will take care of that for you.
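A minimal, self-contained sketch of that suggestion (the `BatchFeature` contents are hypothetical and mimic a wav2vec2-style output):

```python
import torch
from transformers import BatchFeature

# Hypothetical feature-extractor output: a float tensor plus an integer attention mask.
processed = BatchFeature(
    {
        "input_values": torch.zeros(1, 16000, dtype=torch.float32),
        "attention_mask": torch.ones(1, 16000, dtype=torch.int32),
    }
)

# Calling `to` on the BatchFeature casts only the floating-point tensors,
# so the attention mask keeps its integer dtype.
processed = processed.to(dtype=torch.float16)
print({k: v.dtype for k, v in processed.items()})
# input_values -> torch.float16, attention_mask -> torch.int32
```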
Done. Thanks, TIL
Thanks for working on this! My only comment is to make sure to leverage the `to` method on `BatchFeature` (if the feature extractor here returns another type, maybe make sure its `to` method handles dtype arguments) so that checks like not converting int inputs are applied for free.

Otherwise LGTM!
    inputs_len = inputs.shape[0]
    step = chunk_len - stride_left - stride_right
    for i in range(0, inputs_len, step):
        # add start and end paddings to the chunk
        chunk = inputs[i : i + chunk_len]
        processed = feature_extractor(chunk, sampling_rate=feature_extractor.sampling_rate, return_tensors="pt")
        if dtype is not None:
            processed = {k: v.to(dtype=dtype) for k, v in processed.items()}
I believe you can call `to` directly on `processed`, which is a `BatchFeature` and handles `dtype` in its `to` method thanks to #20536 (it was designed for vision but I think it will apply here too).
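Concretely, the suggestion presumably turns the guarded dict comprehension into a single call on the returned `BatchFeature` (a sketch, not necessarily the exact final diff):

```python
processed = feature_extractor(chunk, sampling_rate=feature_extractor.sampling_rate, return_tensors="pt")
if dtype is not None:
    # BatchFeature.to only casts floating-point tensors, so integer inputs
    # such as attention masks are left untouched.
    processed = processed.to(dtype=dtype)
```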
Done !
@@ -249,7 +253,8 @@ def _sanitize_parameters(self, **kwargs):

         return preprocess_params, {}, postprocess_params

-    def preprocess(self, inputs, chunk_length_s=0, stride_length_s=None, ignore_warning=False):
+    def preprocess(self, inputs, chunk_length_s=0, stride_length_s=None, ignore_warning=False, dtype=None):
+        print(f"Running with dtype {dtype}")
To be cleaned up ;-)
Oops
* Supporting `fp16` for asr pipeline
* Adding test.
* Style.
* Oops.
* Flake8 update ?
* Fixing flake8 ?
* Revert "Flake8 update ?" This reverts commit 0b917fc.
* Style (accidentally deleted flake8 F401.)
* Move to a bigger test (no small whisper model, and s2t doesn't seem to accept torch_dtype=fp16). Also we need to use a GPU to actually compute on fp16.
* Using BatchFeature capability.
What does this PR do?
Fixes #20862
Many things were considered before settling for this design.

* `feature_extractor(return_tensors="pt", torch_dtype=torch_dtype)`. This would have the advantage of being consistent, but not all feature extractors define this, so it would affect all of them. Then why would we use `torch_dtype` instead of the more commonplace `dtype`, which could apply to TF and Flax as well? It also feels a bit redundant to specify both `return_tensors` and `torch_dtype`; they would be good candidates to fuse into a single parameter (but that is outside the scope of this PR).
* `AutoFeatureExtractor.from_pretrained(..., torch_dtype=torch_dtype)`. This would have the advantage of being set globally, so users don't need to respecify it on each call. However, we can't specify `return_tensors="pt"` there either, so for consistency I didn't try to put it there.
* `ffmpeg_read(..., dtype=dtype)`. This would be nice, loading the waveform directly into fp16 and just letting fp16 flow through the feature extractor. However, whisper in particular uses a mel spectrogram, so using fp16 sound might actually damage performance.

In the end, this solution is the simplest I could come up with: let `torch_dtype` flow to the pipeline, use it as a regular parameter, and convert the output of the feature extractor afterwards. This does incur a potential extra copy, but there's no risk of damaging the quality of the input.
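For reference, the intended usage would look roughly like the sketch below (the checkpoint and audio file are illustrative; a CUDA device is assumed since fp16 compute needs a GPU):

```python
import torch
from transformers import pipeline

# Illustrative checkpoint and input file.
pipe = pipeline(
    "automatic-speech-recognition",
    model="openai/whisper-large",
    torch_dtype=torch.float16,
    device=0,  # GPU required to actually compute in fp16
)
result = pipe("sample.flac")
print(result["text"])
```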
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.