
Add implementation of spectrogram_batch #27159

Merged
merged 16 commits on Jun 24, 2024

Conversation

ravenouse
Contributor

@ravenouse ravenouse commented Oct 30, 2023

What does this PR do?

This pull request introduces an implementation of spectrogram_batch, optimized for batch processing through broadcasting. The primary goal is to reduce data processing time during feature extraction, which is a critical step when working with audio models like Whisper.

Motivation

In my work and research with the Whisper model, I observed that the feature extraction step can be exceedingly time-consuming, taking up to 10 hours for certain audio datasets.

In my opinion, the bottleneck is primarily due to the lack of batch processing support in the current spectrogram and FeatureExtractor implementations, resulting in iterative calls within a for-loop, as illustrated below:

# Reference: https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/feature_extraction_whisper.py#L250
input_features = [self._np_extract_fbank_features(waveform) for waveform in input_features[0]]
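The loop above processes one waveform at a time. A minimal sketch of the broadcasting idea, with toy helper names and assuming equal-length inputs (this is not the PR's actual code): frame every waveform at once with one fancy-indexing operation, then run a single vectorized FFT over the whole batch.

```python
import numpy as np

def frame_batch(waveforms: np.ndarray, frame_length: int, hop_length: int) -> np.ndarray:
    """Slice a (batch, time) array into (batch, num_frames, frame_length) windows."""
    num_frames = 1 + (waveforms.shape[1] - frame_length) // hop_length
    # Build all window indices at once; broadcasting gives shape (num_frames, frame_length).
    idx = np.arange(frame_length)[None, :] + hop_length * np.arange(num_frames)[:, None]
    return waveforms[:, idx]  # shape: (batch, num_frames, frame_length)

def spectrogram_batch_sketch(waveforms, window, hop_length):
    frames = frame_batch(waveforms, len(window), hop_length)
    # One vectorized FFT over the whole batch instead of a Python for-loop.
    return np.abs(np.fft.rfft(frames * window, axis=-1))

batch = np.random.randn(4, 16000)  # 4 equal-length waveforms
spec = spectrogram_batch_sketch(batch, np.hanning(400), 160)
print(spec.shape)  # (4, 98, 201)
```

The real spectrogram_batch must also handle mel filters, power/db conversion, and variable-length inputs; this only shows the core broadcasting idea that removes the per-waveform loop.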

Future Work

The current branch only adds a basic implementation of the spectrogram_batch.
To bring this implementation to production level, I believe several steps still need to be done:

  1. Extensive testing: implementing a comprehensive suite of tests to evaluate the function's performance and correctness across different parameter settings. I have only tested the implementation in the Whisper model's setting.
  2. Integration: modifying the existing feature extractor code to incorporate the new batch processing function.

I am fully committed to continuing the development and testing of this feature. However, given the extensive changes and the potential impact on the library, I am reaching out for collaboration and support from the community/maintainers. Any guidance, suggestions, or contributions to this effort would be immensely appreciated.

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@sanchit-gandhi
@ArthurZucker

Collaborator

@ArthurZucker ArthurZucker left a comment

Hey! This looks promising and I think it would be welcome! We had to go back to using torchaudio for some processing specifically because of the overhead.
I think something to look forward to is also chunking a single long file into batches for even faster pre-processing!

FYI @ylacombe

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.

@sanchit-gandhi
Contributor

Super cool PR @ravenouse, a most welcome upgrade! As mentioned by @ArthurZucker, one other optimisation we can employ is using the torch.stft backend for computing the spectrograms: #26119 (comment). This yields approx 4x speed-up for processing with bs=1. I believe @ylacombe is considering adding support for this, which could be a nice parallel to your PR!

@ylacombe
Contributor

ylacombe commented Nov 6, 2023

Thanks for working on this @ravenouse, this looks super promising and will clearly be valuable for any audio-related model!

I'm indeed considering adding a torch alternative to this numpy implementation! What would be really good for these current/future improvements is to conduct extensive speed benchmarks, to allow users to make an informed choice when choosing an implementation!

@ravenouse
Contributor Author

Thank you so much for all the feedback and information! @ArthurZucker @sanchit-gandhi @ylacombe

I am excited to learn about the 4x speed-up brought by torch.stft.
Inspired by the experiment from @sanchit-gandhi, I conducted a similar experiment for spectrogram_batch, resulting in a 2x speedup at bs=200 compared to the original function at bs=1.
Link to the experimenting notebook: https://colab.research.google.com/drive/1aXytDfXiMy_tzvjP9A4rM7Z24jmV2Ha-?usp=sharing

For further enhancement, I believe implementing code that enables GPU acceleration for feature extraction and providing users with the option to select GPUs would be an incredible step forward. Prior to submitting this PR, I experimented with a CuPy version of spectrogram_batch. My initial findings indicate that the CuPy Batch version is even faster than the Numpy Batch version, if the GPU memory is managed effectively.
I anticipate that the torch GPU version can achieve comparable performance. An experimental notebook exploring this aspect will be shared shortly.

Once again, I am more than happy to contribute to the package in this direction. Please let me know if there is anything else I can do to further the effort.

@huggingface huggingface deleted a comment from github-actions bot Dec 1, 2023
@ArthurZucker
Collaborator

cc @ylacombe and @sanchit-gandhi let's try to maybe merge this now and have the rest as a todo?

Contributor

@ylacombe ylacombe left a comment

Hey @ravenouse,
First of all, thank you for taking such a useful initiative, and for your time and effort, it's already a great PR! Also sorry for taking such a long time to review!

Before going any further, I think we need to discuss several points!

As you rightfully pointed out, the current implementation processes samples one-by-one, resulting in sub-optimal speed.
At the moment, however, your implementation expects and works with a batch of same-size inputs. But how do we deal with inputs of different sizes?

From your code, I'm under the impression that we first have to pad the inputs, then post-process outside of the function to get the right results. Is that correct?
In that case, does it give the expected results? In other words, do we get the same results by processing the samples one-by-one versus processing them in a batch? With your current implementation and tests, we cannot know!

IMO, we should do this pre and post-processing directly in the spectrogram_batch. So I would expect:

  • spectrogram_batch to take as input a list of numpy arrays (i.e. a list of 1-D audio waveforms) and to output a list of different-size spectrograms
  • spectrogram_batch to pad the inputs, then do batch processing (while keeping track of the output sizes), and finally to output the list of different-size spectrograms by truncating each spectrogram according to its corresponding output size
  • tests that make sure that inputs (random or not) of different sizes passed to spectrogram_batch give the same results as processing them one-by-one

Once that's done, we'll be sure that the function behaves as expected!

Again, thank you for your effort here, let me know if I can be of any help here!
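The pad, batch-process, truncate contract described in this review could be sketched as a small wrapper. All names here are hypothetical: batch_fn stands in for any batched core that requires equal-length inputs, and frames_per_sample maps a raw input length to its valid frame count.

```python
import numpy as np

def process_variable_length(waveform_list, batch_fn, frames_per_sample):
    """Pad to a common length, run one batched call, then truncate each output."""
    lengths = [len(w) for w in waveform_list]
    max_len = max(lengths)
    # Right-pad every waveform to the longest one so they stack into (batch, time).
    padded = np.stack([np.pad(w, (0, max_len - len(w))) for w in waveform_list])
    out = batch_fn(padded)  # (batch, num_frames, bins)
    # Truncate each result back to the frames its un-padded input produced.
    return [out[i, : frames_per_sample(n)] for i, n in enumerate(lengths)]

# Toy core: one "bin" per time step, so truncation is easy to eyeball.
toy_fn = lambda x: x[:, :, None]
outs = process_variable_length([np.arange(3.0), np.arange(5.0)], toy_fn, lambda n: n)
print([o.shape for o in outs])  # [(3, 1), (5, 1)]
```

The key point is that padding happens inside the function and the per-sample output sizes are tracked, so callers receive a list of different-size spectrograms rather than a padded block.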

db_range: Optional[float] = None,
remove_dc_offset: Optional[bool] = None,
dtype: np.dtype = np.float32,
) -> np.ndarray:
Contributor

We're missing a docstring here, similar to the one for spectrogram.

Comment on lines 575 to 577
spectrogram = amplitude_to_db(spectrogram, reference, min_value, db_range)
elif power == 2.0:
spectrogram = power_to_db(spectrogram, reference, min_value, db_range)
Contributor

I believe these two functions also need to be adapted for batching!

Contributor

Did you have time to look into this? It seems unresolved.

@ravenouse
Contributor Author

Hi @ylacombe,
Thank you so much for taking the time to review this PR. I really appreciate the insightful inputs you have shared.

I will definitely follow the suggested directions, ensuring that the spectrogram_batch generates the same results as the original function.

Currently, I am working towards a deadline this week. Once that is completed, I will prioritize making the necessary modifications to the PR.

Once again, thank you for your valuable input and support!

@ravenouse
Contributor Author

ravenouse commented Dec 28, 2023

Hi @ylacombe,

I hope you had a wonderful break!

Following your feedback, I've updated the spectrogram_batch function to enhance its pre and post-processing capabilities. The key modifications include:

  • Utilizing original_waveform_lengths and true_num_frames to capture and truncate the results to their true lengths.
  • Currently, the function relies on the pad function of the SequenceFeatureExtractor, as detailed here. This results in some redundancy in the code.

You can review the revised function in this Colab notebook. It produces the same results as the original spectrogram when tested on hf-internal-testing/librispeech_asr_dummy.

For the next step, I plan to:

  1. Implement a simple and efficient batch padding method, eliminating the current reliance on SequenceFeatureExtractor.
  2. Implement batch versions of amplitude_to_db and power_to_db.
  3. Add function annotations and docstrings to the functions.

Please let me know your thoughts on this.
Thank you very much!

@ylacombe
Contributor

ylacombe commented Jan 1, 2024

Hey @ravenouse, thanks for all the progress so far and happy new year! Could you actually update the code here so that it's easier to review, test and keep track of the comments?
Many thanks!

@ravenouse
Contributor Author

Hi @ylacombe, happy new year!

I have updated the code: I modified spectrogram_batch further, eliminating the previous dependency on SequenceFeatureExtractor for batch-padding the waveforms.

In case you want to run the test on your own, here is the updated notebook: Link

Please let me know what you think about this!

Thank you so much for your time!

@sanchit-gandhi
Contributor

Gently pinging @ylacombe


This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@github-actions github-actions bot closed this Feb 28, 2024
@ArthurZucker ArthurZucker reopened this Feb 29, 2024
Contributor

@ylacombe ylacombe left a comment

Hey @ravenouse, again sorry for the wait!

This is definitely a great improvement! Your code looks great; in the next steps I'd focus on:

  • adapting the tests to your new implementation. Remember that we want to test that different waveforms with different lengths produce the same results in batch or sequentially.
  • adding docstrings

Testing extensively will also allow us to pinpoint issues that you and I may have overlooked while writing/reviewing the code (e.g. I'm not totally sure it works with amplitude and power spectrograms!).

Again, thanks for the great work!

src/transformers/audio_utils.py
Comment on lines 527 to 535
padding = [(int(frame_length // 2), int(frame_length // 2))]
padded_waveform_list = [
np.pad(
waveform,
padding,
mode=pad_mode,
)
for waveform in waveform_list
]
Contributor

Out of curiosity, shouldn't the padding be adapted for each waveform in the list?

Comment on lines 542 to 548
padded_waveform_batch = np.array(
[
np.pad(waveform, (0, max_length - len(waveform)), mode="constant", constant_values=0)
for waveform in padded_waveform_list
],
dtype=dtype,
)
Contributor

Too bad we don't have a numpy equivalent to torch.nn.utils.rnn.pad_sequence
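For what it's worth, a NumPy stand-in for torch.nn.utils.rnn.pad_sequence takes only a few lines. This is a sketch of a hypothetical helper, not an existing NumPy API:

```python
import numpy as np

def pad_sequence_np(arrays, padding_value=0.0, dtype=np.float32):
    """Right-pad a list of 1-D arrays to the length of the longest one."""
    max_len = max(len(a) for a in arrays)
    # Start from a full padding matrix, then copy each sequence over it.
    out = np.full((len(arrays), max_len), padding_value, dtype=dtype)
    for i, a in enumerate(arrays):
        out[i, : len(a)] = a
    return out

batch = pad_sequence_np([np.ones(2), np.ones(4)])
print(batch.shape)  # (2, 4)
```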

Comment on lines 575 to 577
spectrogram = amplitude_to_db(spectrogram, reference, min_value, db_range)
elif power == 2.0:
spectrogram = power_to_db(spectrogram, reference, min_value, db_range)
Contributor

Did you have time to look into this? It seems unresolved.


spectrogram = np.asarray(spectrogram, dtype)

spectrogram_list = [spectrogram[i, : true_num_frames[i], :].T for i in range(len(true_num_frames))]
Contributor

Great!

Contributor

I've got the exact same comment as last time:

  • tests that make sure that inputs (random or not) of different sizes passed to spectrogram_batch give the same results as processing them one-by-one

We want waveforms of different lengths to produce the same results, whether transmitted in batches or one at a time.
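Such an equivalence test could look roughly like this. The toy spec_single/spec_batch pair below stands in for the library's spectrogram and spectrogram_batch (hypothetical names and a simplified STFT, sharing the pad-then-truncate contract):

```python
import numpy as np

def spec_single(waveform, frame_length=8, hop=4):
    # Sequential reference: frame one waveform, FFT each frame.
    n = 1 + (len(waveform) - frame_length) // hop
    frames = np.stack([waveform[i * hop : i * hop + frame_length] for i in range(n)])
    return np.abs(np.fft.rfft(frames, axis=-1))

def spec_batch(waveform_list, frame_length=8, hop=4):
    # Pad, process as one array, truncate: mirroring spectrogram_batch's contract.
    lengths = [len(w) for w in waveform_list]
    max_len = max(lengths)
    padded = np.stack([np.pad(w, (0, max_len - len(w))) for w in waveform_list])
    total = 1 + (max_len - frame_length) // hop
    frames = np.stack([padded[:, i * hop : i * hop + frame_length] for i in range(total)], axis=1)
    specs = np.abs(np.fft.rfft(frames, axis=-1))
    n_frames = [1 + (length - frame_length) // hop for length in lengths]
    return [specs[i, : n_frames[i]] for i in range(len(waveform_list))]

rng = np.random.default_rng(0)
waves = [rng.standard_normal(n) for n in (32, 48, 25)]
batched = spec_batch(waves)
# Different-length inputs must give identical results batched or one-by-one.
assert all(np.allclose(spec_single(w), b) for w, b in zip(waves, batched))
```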

@github-actions github-actions bot closed this Mar 9, 2024
@ravenouse
Contributor Author

Hi @ylacombe,

Thank you so much for your input!

I will address the problems and directions you've highlighted, focusing specifically on:

  1. Testing,
  2. Implementing batch versions of amplitude_to_db and power_to_db.

Thank you again!

@ravenouse
Contributor Author

ravenouse commented Mar 18, 2024

Hi @ylacombe,

Good morning!

Could you please reopen this PR?

I will work actively on this PR: I am aiming to finish the above-mentioned problems/to-dos this week.

Thank you so much!

@ylacombe
Contributor

Hey @ravenouse, of course!

@ylacombe ylacombe reopened this Mar 18, 2024
@ravenouse
Contributor Author

Hi @ylacombe,

Good day!

I wanted to provide a quick update: the testing phase is nearly complete.
I am in the process of organizing the function names and preparing the notebook to demonstrate how the expected values used in the tests are generated.

This PR will be ready for your review by next Monday.

Thank you so much for all your support!

@ravenouse
Contributor Author

ravenouse commented Mar 26, 2024

Hi @ylacombe,

I believe this PR is now ready for your review!
To facilitate the review process, I've provided a link to a Colab notebook below. This notebook demonstrates the methodology used to generate the expected values used in the test functions.
Link: https://colab.research.google.com/drive/1bzYz8MydhuHUAIZauzHjrjMQ_Of89ZAj?usp=sharing

Please let me know your thoughts on this and if there is anything I should work on further!

Thank you so much!

@ravenouse ravenouse changed the title [WIP] Add implementation of spectrogram_batch Add implementation of spectrogram_batch Mar 27, 2024
@ylacombe
Contributor

ylacombe commented Apr 1, 2024

Hey @ravenouse, thanks for updating the code and for the Colab, I'll take a look ASAP!

@huggingface huggingface deleted a comment from github-actions bot Apr 25, 2024
@amyeroberts
Collaborator

Gentle ping @ylacombe

@huggingface huggingface deleted a comment from github-actions bot May 20, 2024
@huggingface huggingface deleted a comment from github-actions bot Jun 14, 2024
@amyeroberts
Collaborator

Another ping @ylacombe

Contributor

@ylacombe ylacombe left a comment

Hey @ravenouse, I hope you're doing well!
First of all, I'm really sorry for the late review and thank you so much for working on this!

This looks great to me as it is! Let's ask for a core maintainer review!

In the meantime, I believe we should:

  1. re-add the test that was removed by mistake
  2. conduct a speed benchmark compared to the original method

Again, thanks for the great work here, this is a beautiful PR!

Comment on lines +884 to +886
# Apply db_range clipping per batch item
max_values = spectrogram.max(axis=(1, 2), keepdims=True)
spectrogram = np.clip(spectrogram, a_min=max_values - db_range, a_max=None)
Contributor

Looking clean 🔥 !
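A quick toy check (not the PR's code) that the keepdims max makes np.clip bound each batch item against its own peak rather than a global one:

```python
import numpy as np

db_range = 10.0
# Two batch items with different peaks: 0 dB and 20 dB.
spectrogram = np.array([[[0.0, -30.0]], [[20.0, -30.0]]])
# keepdims=True keeps shape (batch, 1, 1), so the clip floor broadcasts per item.
max_values = spectrogram.max(axis=(1, 2), keepdims=True)
clipped = np.clip(spectrogram, a_min=max_values - db_range, a_max=None)
print(clipped[:, 0, 1])  # [-10.  10.]: each item is floored at its own max - 10
```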

Comment on lines +975 to +980
if db_range is not None:
if db_range <= 0.0:
raise ValueError("db_range must be greater than zero")
# Apply db_range clipping per batch item
max_values = spectrogram.max(axis=(1, 2), keepdims=True)
spectrogram = np.clip(spectrogram, a_min=max_values - db_range, a_max=None)
Contributor

Great as well!

Comment on lines 33 to 37
from transformers.testing_utils import is_librosa_available


if is_librosa_available():
from librosa.filters import chroma
pass
Contributor

Why are we checking whether librosa is available if we don't import anything at the end of the day?

Contributor

I think you removed a test by mistake!

Comment on lines 765 to 766
@require_librosa
def test_chroma_equivalence(self):
Contributor

I think you removed this test, probably by mistake, right? In that case, let's put it back and re-import librosa!

Contributor

Thanks so much for adding these thorough tests, looks like the method works as expected now!

@ravenouse
Contributor Author

Hi @ylacombe,

Good day! Thank you very much for your careful reviews on this PR.

I have restored the mistakenly deleted function and created a notebook to demonstrate the speed improvements for the batch functions. You can find the notebook below:
https://colab.research.google.com/drive/1UuSx7Apa6FRlO1DE6QwnZbboB2KWTQbm?usp=sharing

Please let me know if there is anything else I can do to facilitate this PR.
Thank you so much for all your time and help on this!

@ylacombe
Contributor

Benchmarks are looking great, thanks for the effort!

cc @amyeroberts, could you review this PR?

Collaborator

@amyeroberts amyeroberts left a comment

Nice 🔥 !

Very clean and easy to follow code - thanks for adding in such extensive tests too ❤️

src/transformers/audio_utils.py
Comment on lines +713 to +715
original_waveform_lengths = [
len(waveform) for waveform in waveform_list
] # these lengths will be used to remove padding later
Collaborator

nit: library convention is to put comments on the line above, to avoid line-splitting like this

Suggested change
original_waveform_lengths = [
len(waveform) for waveform in waveform_list
] # these lengths will be used to remove padding later
# these lengths will be used to remove padding later
original_waveform_lengths = [len(waveform) for waveform in waveform_list]

src/transformers/audio_utils.py
ravenouse and others added 2 commits June 20, 2024 13:55
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
@ylacombe
Contributor

Thanks again for this @ravenouse! Let's discuss how to integrate this into some feature extractors as a follow-up. Would you like to give it a try?

@ravenouse
Contributor Author

ravenouse commented Jun 21, 2024

Hi @ylacombe and @amyeroberts,

Thank you for your careful reviews and kind feedback.

Yes, I would love to work more on this!

Since the spectrogram_batch function is a faster equivalent to the spectrogram function, which has been used in many audio models' feature extraction parts, I believe we can start by replacing spectrogram in the batch settings. Having worked with whisper and seamless_m4t models before, I propose starting with these models.

Namely, I plan to implement an _extract_fbank_features_batch class method for the models, which will mimic the _extract_fbank_features class method but be optimized to avoid for-loops. For example, _extract_fbank_features_batch will replace the for-loop here, at line 259 in seamless_m4t's feature extraction code.

What do you think about this plan?

Thank you so much again!

Edit: Removed whisper from the plan list since it already has the _torch_extract_fbank_features function.
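Schematically, the proposed change swaps a per-waveform Python loop for one vectorized call. The sketch below uses a toy computation, a hypothetical class name, and assumes equal-length inputs for brevity; it is not the actual feature extractor internals:

```python
import numpy as np

class FeatureExtractorSketch:
    """Toy stand-in illustrating the proposed loop-to-batch-call replacement."""

    def _extract_fbank_features(self, waveform):
        # Placeholder per-waveform computation (NOT real fbank extraction).
        return waveform.reshape(-1, 4).mean(axis=1)

    def _extract_fbank_features_batch(self, waveforms):
        # One vectorized computation over the whole (batch, time) array.
        return np.stack(waveforms).reshape(len(waveforms), -1, 4).mean(axis=2)

fe = FeatureExtractorSketch()
waves = [np.arange(8.0), np.arange(8.0) * 2]
looped = [fe._extract_fbank_features(w) for w in waves]  # current pattern
batched = fe._extract_fbank_features_batch(waves)        # proposed pattern
assert all(np.allclose(a, b) for a, b in zip(looped, batched))
```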

@ylacombe ylacombe merged commit dce253f into huggingface:main Jun 24, 2024
19 checks passed
@ylacombe
Contributor

Hey @ravenouse, thanks for following that up!
The plan sounds good!
There are only a few models that use _extract_fbank_features; shall we do them all at once?
If you feel more comfortable starting with Seamless M4T, that works too. Don't hesitate to ask for a review early on!

zucchini-nlp pushed a commit to zucchini-nlp/transformers that referenced this pull request Jun 24, 2024
* Add initial implementation of `spectrogram_batch`

* Format the initial implementation

* Add test suite for the `spectrogram_batch`

* Update `spectrogram_batch` to ensure compatibility with test suite

* Update `spectrogram_batch` to include pre and post-processing

* Add `amplitude_to_db_batch` function and associated tests

* Add `power_to_db_batch` function and associated tests

* Reimplement the test suite for `spectrogram_batch`

* Fix errors in `spectrogram_batch`

* Add the function annotation for `spectrogram_batch`

* Address code quality

* Re-add `test_chroma_equivalence` function

* Update src/transformers/audio_utils.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/audio_utils.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

---------

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
@ravenouse
Contributor Author

Hi @ylacombe, sounds great!
I will create a new PR for this soon!
