diff --git a/examples/tutorials/audio_resampling_tutorial.py b/examples/tutorials/audio_resampling_tutorial.py index 65e2a313e8a..de840b6b610 100644 --- a/examples/tutorials/audio_resampling_tutorial.py +++ b/examples/tutorials/audio_resampling_tutorial.py @@ -3,14 +3,9 @@ Audio Resampling ================ -Here, we will walk through resampling audio waveforms using ``torchaudio``. - +This tutorial shows how to use torchaudio's resampling API. """ -# When running this tutorial in Google Colab, install the required packages -# with the following. -# !pip install torchaudio librosa - import torch import torchaudio import torchaudio.functional as F @@ -20,18 +15,18 @@ print(torchaudio.__version__) ###################################################################### -# Preparing data and utility functions (skip this section) -# -------------------------------------------------------- +# 1. Preparation +# -------------- # - -# @title Prepare data and utility functions. {display-mode: "form"} -# @markdown -# @markdown You do not need to look into this cell. -# @markdown Just execute once and you are good to go. - -# ------------------------------------------------------------------------------- -# Preparation of data and helper functions. -# ------------------------------------------------------------------------------- +# Firstly, we import the modules and define the helper functions. +# +# .. note:: +# When running this tutorial in Google Colab, install the required packages +# with the following. +# +# .. code:: +# +# !pip install librosa import math import time @@ -41,12 +36,7 @@ import pandas as pd from IPython.display import Audio, display - DEFAULT_OFFSET = 201 -SWEEP_MAX_SAMPLE_RATE = 48000 -DEFAULT_LOWPASS_FILTER_WIDTH = 6 -DEFAULT_ROLLOFF = 0.99 -DEFAULT_RESAMPLING_METHOD = "sinc_interpolation" def _get_log_freq(sample_rate, max_sweep_rate, offset): @@ -95,7 +85,7 @@ def plot_sweep( waveform, sample_rate, title, - max_sweep_rate=SWEEP_MAX_SAMPLE_RATE, + max_sweep_rate=48000, offset=DEFAULT_OFFSET, ): x_ticks = [100, 500, 1000, 5000, 10000, 20000, max_sweep_rate // 2] @@ -103,10 +93,10 @@ def plot_sweep( time, freq = _get_freq_ticks(max_sweep_rate, offset, sample_rate // 2) freq_x = [f if f in x_ticks and f <= max_sweep_rate // 2 else None for f in freq] - freq_y = [f for f in freq if f >= 1000 and f in y_ticks and f <= sample_rate // 2] + freq_y = [f for f in freq if f in y_ticks and 1000 <= f <= sample_rate // 2] figure, axis = plt.subplots(1, 1) - axis.specgram(waveform[0].numpy(), Fs=sample_rate) + _, _, _, cax = axis.specgram(waveform[0].numpy(), Fs=sample_rate) plt.xticks(time, freq_x) plt.yticks(freq_y, freq_y) axis.set_xlabel("Original Signal Frequency (Hz, log scale)") @@ -114,90 +104,13 @@ def plot_sweep( axis.xaxis.grid(True, alpha=0.67) axis.yaxis.grid(True, alpha=0.67) figure.suptitle(f"{title} (sample rate: {sample_rate} Hz)") + plt.colorbar(cax) plt.show(block=True) -def play_audio(waveform, sample_rate): - waveform = waveform.numpy() - - num_channels, num_frames = waveform.shape - if num_channels == 1: - display(Audio(waveform[0], rate=sample_rate)) - elif num_channels == 2: - display(Audio((waveform[0], waveform[1]), rate=sample_rate)) - else: - raise ValueError("Waveform with more than 2 channels are not supported.") - - -def plot_specgram(waveform, sample_rate, title="Spectrogram", xlim=None): - waveform = waveform.numpy() - - num_channels, num_frames = waveform.shape - - figure, axes = plt.subplots(num_channels, 1) - if num_channels == 1: - axes = [axes] - for c in range(num_channels): - axes[c].specgram(waveform[c], Fs=sample_rate) - if num_channels > 1: - axes[c].set_ylabel(f"Channel {c+1}") - if xlim: - axes[c].set_xlim(xlim) - figure.suptitle(title) - plt.show(block=False) - - -def benchmark_resample( - method, - waveform, - sample_rate, - resample_rate, - lowpass_filter_width=DEFAULT_LOWPASS_FILTER_WIDTH, - rolloff=DEFAULT_ROLLOFF, - resampling_method=DEFAULT_RESAMPLING_METHOD, - beta=None, - librosa_type=None, - iters=5, -): - if method == "functional": - begin = time.time() - for _ in range(iters): - F.resample( - waveform, - sample_rate, - resample_rate, - lowpass_filter_width=lowpass_filter_width, - rolloff=rolloff, - resampling_method=resampling_method, - ) - elapsed = time.time() - begin - return elapsed / iters - elif method == "transforms": - resampler = T.Resample( - sample_rate, - resample_rate, - lowpass_filter_width=lowpass_filter_width, - rolloff=rolloff, - resampling_method=resampling_method, - dtype=waveform.dtype, - ) - begin = time.time() - for _ in range(iters): - resampler(waveform) - elapsed = time.time() - begin - return elapsed / iters - elif method == "librosa": - waveform_np = waveform.squeeze().numpy() - begin = time.time() - for _ in range(iters): - librosa.resample(waveform_np, orig_sr=sample_rate, target_sr=resample_rate, res_type=librosa_type) - elapsed = time.time() - begin - return elapsed / iters - - ###################################################################### -# Resampling Overview -# ------------------- +# 2. Resampling Overview +# ---------------------- # # To resample an audio waveform from one freqeuncy to another, you can use # :py:func:`torchaudio.transforms.Resample` or @@ -211,11 +124,14 @@ def benchmark_resample( # interpolation `__ to compute # signal values at arbitrary time steps. The implementation involves # convolution, so we can take advantage of GPU / multithreading for -# performance improvements. When using resampling in multiple -# subprocesses, such as data loading with multiple worker processes, your -# application might create more threads than your system can handle -# efficiently. Setting ``torch.set_num_threads(1)`` might help in this -# case. +# performance improvements. +# +# .. note:: +# +# When using resampling in multiple subprocesses, such as data loading +# with multiple worker processes, your application might create more +# threads than your system can handle efficiently. +# Setting ``torch.set_num_threads(1)`` might help in this case. # # Because a finite number of samples can only represent a finite number of # frequencies, resampling does not produce perfect results, and a variety @@ -230,26 +146,32 @@ def benchmark_resample( # plotted waveform, and color intensity the amplitude. # - sample_rate = 48000 -resample_rate = 32000 - waveform = get_sine_sweep(sample_rate) + plot_sweep(waveform, sample_rate, title="Original Waveform") -play_audio(waveform, sample_rate) +Audio(waveform.numpy()[0], rate=sample_rate) +###################################################################### +# +# Now we resample (downsample) it. +# +# We see that in the spectrogram of the resampled waveform, there is an +# artifact, which was not present in the original waveform. + +resample_rate = 32000 resampler = T.Resample(sample_rate, resample_rate, dtype=waveform.dtype) resampled_waveform = resampler(waveform) -plot_sweep(resampled_waveform, resample_rate, title="Resampled Waveform") -play_audio(waveform, sample_rate) +plot_sweep(resampled_waveform, resample_rate, title="Resampled Waveform") +Audio(resampled_waveform.numpy()[0], rate=resample_rate) ###################################################################### -# Controling resampling quality with parameters -# --------------------------------------------- +# 3. Controling resampling quality with parameters +# ------------------------------------------------ # -# Lowpass filter width -# ~~~~~~~~~~~~~~~~~~~~ +# 3.1. Lowpass filter width +# ------------------------- # # Because the filter used for interpolation extends infinitely, the # ``lowpass_filter_width`` parameter is used to control for the width of @@ -260,20 +182,21 @@ def benchmark_resample( # expensive. # - sample_rate = 48000 resample_rate = 32000 resampled_waveform = F.resample(waveform, sample_rate, resample_rate, lowpass_filter_width=6) plot_sweep(resampled_waveform, resample_rate, title="lowpass_filter_width=6") +###################################################################### +# + resampled_waveform = F.resample(waveform, sample_rate, resample_rate, lowpass_filter_width=128) plot_sweep(resampled_waveform, resample_rate, title="lowpass_filter_width=128") - ###################################################################### -# Rolloff -# ~~~~~~~ +# 3.2. Rolloff +# ------------ # # The ``rolloff`` parameter is represented as a fraction of the Nyquist # frequency, which is the maximal frequency representable by a given @@ -291,13 +214,16 @@ def benchmark_resample( resampled_waveform = F.resample(waveform, sample_rate, resample_rate, rolloff=0.99) plot_sweep(resampled_waveform, resample_rate, title="rolloff=0.99") +###################################################################### +# + resampled_waveform = F.resample(waveform, sample_rate, resample_rate, rolloff=0.8) plot_sweep(resampled_waveform, resample_rate, title="rolloff=0.8") ###################################################################### -# Window function -# ~~~~~~~~~~~~~~~ +# 3.3. Window function +# -------------------- # # By default, ``torchaudio``’s resample uses the Hann window filter, which is # a weighted cosine function. It additionally supports the Kaiser window, @@ -314,23 +240,28 @@ def benchmark_resample( resampled_waveform = F.resample(waveform, sample_rate, resample_rate, resampling_method="sinc_interpolation") plot_sweep(resampled_waveform, resample_rate, title="Hann Window Default") +###################################################################### +# + resampled_waveform = F.resample(waveform, sample_rate, resample_rate, resampling_method="kaiser_window") plot_sweep(resampled_waveform, resample_rate, title="Kaiser Window Default") ###################################################################### -# Comparison against librosa -# -------------------------- +# 4. Comparison against librosa +# ----------------------------- # # ``torchaudio``’s resample function can be used to produce results similar to # that of librosa (resampy)’s kaiser window resampling, with some noise # - sample_rate = 48000 resample_rate = 32000 -# kaiser_best +###################################################################### +# 4.1. kaiser_best +# ---------------- +# resampled_waveform = F.resample( waveform, sample_rate, @@ -342,15 +273,24 @@ def benchmark_resample( ) plot_sweep(resampled_waveform, resample_rate, title="Kaiser Window Best (torchaudio)") +###################################################################### +# + librosa_resampled_waveform = torch.from_numpy( librosa.resample(waveform.squeeze().numpy(), orig_sr=sample_rate, target_sr=resample_rate, res_type="kaiser_best") ).unsqueeze(0) plot_sweep(librosa_resampled_waveform, resample_rate, title="Kaiser Window Best (librosa)") +###################################################################### +# + mse = torch.square(resampled_waveform - librosa_resampled_waveform).mean().item() print("torchaudio and librosa kaiser best MSE:", mse) -# kaiser_fast +###################################################################### +# 4.2. kaiser_fast +# ---------------- +# resampled_waveform = F.resample( waveform, sample_rate, @@ -360,20 +300,25 @@ def benchmark_resample( resampling_method="kaiser_window", beta=8.555504641634386, ) -plot_specgram(resampled_waveform, resample_rate, title="Kaiser Window Fast (torchaudio)") +plot_sweep(resampled_waveform, resample_rate, title="Kaiser Window Fast (torchaudio)") + +###################################################################### +# librosa_resampled_waveform = torch.from_numpy( librosa.resample(waveform.squeeze().numpy(), orig_sr=sample_rate, target_sr=resample_rate, res_type="kaiser_fast") ).unsqueeze(0) plot_sweep(librosa_resampled_waveform, resample_rate, title="Kaiser Window Fast (librosa)") +###################################################################### +# + mse = torch.square(resampled_waveform - librosa_resampled_waveform).mean().item() print("torchaudio and librosa kaiser fast MSE:", mse) - ###################################################################### -# Performance Benchmarking -# ------------------------ +# 5. Performance Benchmarking +# --------------------------- # # Below are benchmarks for downsampling and upsampling waveforms between # two pairs of sampling rates. We demonstrate the performance implications @@ -394,6 +339,57 @@ def benchmark_resample( # +def benchmark_resample( + method, + waveform, + sample_rate, + resample_rate, + lowpass_filter_width=6, + rolloff=0.99, + resampling_method="sinc_interpolation", + beta=None, + librosa_type=None, + iters=5, +): + if method == "functional": + begin = time.monotonic() + for _ in range(iters): + F.resample( + waveform, + sample_rate, + resample_rate, + lowpass_filter_width=lowpass_filter_width, + rolloff=rolloff, + resampling_method=resampling_method, + ) + elapsed = time.monotonic() - begin + return elapsed / iters + elif method == "transforms": + resampler = T.Resample( + sample_rate, + resample_rate, + lowpass_filter_width=lowpass_filter_width, + rolloff=rolloff, + resampling_method=resampling_method, + dtype=waveform.dtype, + ) + begin = time.monotonic() + for _ in range(iters): + resampler(waveform) + elapsed = time.monotonic() - begin + return elapsed / iters + elif method == "librosa": + waveform_np = waveform.squeeze().numpy() + begin = time.monotonic() + for _ in range(iters): + librosa.resample(waveform_np, orig_sr=sample_rate, target_sr=resample_rate, res_type=librosa_type) + elapsed = time.monotonic() - begin + return elapsed / iters + + +###################################################################### +# + configs = { "downsample (48 -> 44.1 kHz)": [48000, 44100], "downsample (16 -> 8 kHz)": [16000, 8000],