Skip to content

Commit

Permalink
Remove torch dependency, Faster numpy Feature extraction (SYSTRAN#1106
Browse files Browse the repository at this point in the history
)
  • Loading branch information
MahmoudAshraf97 authored Nov 14, 2024
1 parent 8f01aee commit 3e0ba86
Show file tree
Hide file tree
Showing 6 changed files with 203 additions and 118 deletions.
21 changes: 6 additions & 15 deletions faster_whisper/audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

import av
import numpy as np
import torch


def decode_audio(
Expand Down Expand Up @@ -72,9 +71,9 @@ def decode_audio(
if split_stereo:
left_channel = audio[0::2]
right_channel = audio[1::2]
return torch.from_numpy(left_channel), torch.from_numpy(right_channel)
return left_channel, right_channel

return torch.from_numpy(audio)
return audio


def _ignore_invalid_frames(frames):
Expand Down Expand Up @@ -113,20 +112,12 @@ def pad_or_trim(array, length: int = 3000, *, axis: int = -1):
"""
Pad or trim the Mel features array to 3000, as expected by the encoder.
"""
axis = axis % array.ndim
if array.shape[axis] > length:
idx = [Ellipsis] * axis + [slice(length)] + [Ellipsis] * (array.ndim - axis - 1)
return array[idx]
array = array.take(indices=range(length), axis=axis)

if array.shape[axis] < length:
pad_widths = (
[
0,
]
* array.ndim
* 2
)
pad_widths[2 * axis] = length - array.shape[axis]
array = torch.nn.functional.pad(array, tuple(pad_widths[::-1]))
pad_widths = [(0, 0)] * array.ndim
pad_widths[axis] = (0, length - array.shape[axis])
array = np.pad(array, pad_widths)

return array
206 changes: 161 additions & 45 deletions faster_whisper/feature_extractor.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,15 @@
import torch
import numpy as np


# Adapted from https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/feature_extraction_whisper.py # noqa: E501
class FeatureExtractor:
def __init__(
self,
device: str = "auto",
feature_size=80,
sampling_rate=16000,
hop_length=160,
chunk_length=30,
n_fft=400,
):
if device == "auto":
self.device = "cuda" if torch.cuda.is_available() else "cpu"
else:
self.device = device
self.n_fft = n_fft
self.hop_length = hop_length
self.chunk_length = chunk_length
Expand All @@ -25,24 +19,21 @@ def __init__(
self.sampling_rate = sampling_rate
self.mel_filters = self.get_mel_filters(
sampling_rate, n_fft, n_mels=feature_size
)
).astype("float32")

@staticmethod
def get_mel_filters(sr, n_fft, n_mels=128):
"""
Implementation of librosa.filters.mel in Pytorch
"""
# Initialize the weights
n_mels = int(n_mels)

# Center freqs of each FFT bin
fftfreqs = torch.fft.rfftfreq(n=n_fft, d=1.0 / sr)
fftfreqs = np.fft.rfftfreq(n=n_fft, d=1.0 / sr)

# 'Center freqs' of mel bands - uniformly spaced between limits
min_mel = 0.0
max_mel = 45.245640471924965

mels = torch.linspace(min_mel, max_mel, n_mels + 2)
mels = np.linspace(min_mel, max_mel, n_mels + 2)

# Fill in the linear scale
f_min = 0.0
Expand All @@ -52,30 +43,159 @@ def get_mel_filters(sr, n_fft, n_mels=128):
# And now the nonlinear scale
min_log_hz = 1000.0 # beginning of log region (Hz)
min_log_mel = (min_log_hz - f_min) / f_sp # same (Mels)
logstep = torch.log(torch.tensor(6.4)) / 27.0 # step size for log region
logstep = np.log(6.4) / 27.0 # step size for log region

# If we have vector data, vectorize
log_t = mels >= min_log_mel
freqs[log_t] = min_log_hz * torch.exp(logstep * (mels[log_t] - min_log_mel))

mel_f = freqs
freqs[log_t] = min_log_hz * np.exp(logstep * (mels[log_t] - min_log_mel))

fdiff = torch.diff(mel_f)
ramps = mel_f.view(-1, 1) - fftfreqs.view(1, -1)
fdiff = np.diff(freqs)
ramps = freqs.reshape(-1, 1) - fftfreqs.reshape(1, -1)

lower = -ramps[:-2] / fdiff[:-1].unsqueeze(1)
upper = ramps[2:] / fdiff[1:].unsqueeze(1)
lower = -ramps[:-2] / np.expand_dims(fdiff[:-1], axis=1)
upper = ramps[2:] / np.expand_dims(fdiff[1:], axis=1)

# Intersect them with each other and zero, vectorized across all i
weights = torch.maximum(torch.zeros_like(lower), torch.minimum(lower, upper))
weights = np.maximum(np.zeros_like(lower), np.minimum(lower, upper))

# Slaney-style mel is scaled to be approx constant energy per channel
enorm = 2.0 / (mel_f[2 : n_mels + 2] - mel_f[:n_mels])
weights *= enorm.unsqueeze(1)
enorm = 2.0 / (freqs[2 : n_mels + 2] - freqs[:n_mels])
weights *= np.expand_dims(enorm, axis=1)

return weights

def __call__(self, waveform, padding=True, chunk_length=None, to_cpu=False):
@staticmethod
def stft(
input_array: np.ndarray,
n_fft: int,
hop_length: int = None,
win_length: int = None,
window: np.ndarray = None,
center: bool = True,
mode: str = "reflect",
normalized: bool = False,
onesided: bool = None,
return_complex: bool = None,
):
# Default initialization for hop_length and win_length
hop_length = hop_length if hop_length is not None else n_fft // 4
win_length = win_length if win_length is not None else n_fft
input_is_complex = np.iscomplexobj(input_array)

# Determine if the output should be complex
return_complex = (
return_complex
if return_complex is not None
else (input_is_complex or (window is not None and np.iscomplexobj(window)))
)

if not return_complex and return_complex is None:
raise ValueError(
"stft requires the return_complex parameter for real inputs."
)

# Input checks
if not np.issubdtype(input_array.dtype, np.floating) and not input_is_complex:
raise ValueError(
"stft: expected an array of floating point or complex values,"
f" got {input_array.dtype}"
)

if input_array.ndim > 2 or input_array.ndim < 1:
raise ValueError(
f"stft: expected a 1D or 2D array, but got {input_array.ndim}D array"
)

# Handle 1D input
if input_array.ndim == 1:
input_array = np.expand_dims(input_array, axis=0)
input_array_1d = True
else:
input_array_1d = False

# Center padding if required
if center:
pad_amount = n_fft // 2
input_array = np.pad(
input_array, ((0, 0), (pad_amount, pad_amount)), mode=mode
)

batch, length = input_array.shape

# Additional input checks
if n_fft <= 0 or n_fft > length:
raise ValueError(
f"stft: expected 0 < n_fft <= {length}, but got n_fft={n_fft}"
)

if hop_length <= 0:
raise ValueError(
f"stft: expected hop_length > 0, but got hop_length={hop_length}"
)

if win_length <= 0 or win_length > n_fft:
raise ValueError(
f"stft: expected 0 < win_length <= n_fft, but got win_length={win_length}"
)

if window is not None:
if window.ndim != 1 or window.shape[0] != win_length:
raise ValueError(
f"stft: expected a 1D window array of size equal to win_length={win_length}, "
f"but got window with size {window.shape}"
)

# Handle padding of the window if necessary
if win_length < n_fft:
left = (n_fft - win_length) // 2
window_ = np.zeros(n_fft, dtype=window.dtype)
window_[left : left + win_length] = window
else:
window_ = window

# Calculate the number of frames
n_frames = 1 + (length - n_fft) // hop_length

# Time to columns
input_array = np.lib.stride_tricks.as_strided(
input_array,
(batch, n_frames, n_fft),
(
input_array.strides[0],
hop_length * input_array.strides[1],
input_array.strides[1],
),
)

if window_ is not None:
input_array = input_array * window_

# FFT and transpose
complex_fft = input_is_complex
onesided = onesided if onesided is not None else not complex_fft

if normalized:
norm = "ortho"
else:
norm = None

if complex_fft:
if onesided:
raise ValueError(
"Cannot have onesided output if window or input is complex"
)
output = np.fft.fft(input_array, n=n_fft, axis=-1, norm=norm)
else:
output = np.fft.rfft(input_array, n=n_fft, axis=-1, norm=norm)

output = output.transpose((0, 2, 1))

if input_array_1d:
output = output.squeeze(0)

return output if return_complex else np.real(output)

def __call__(self, waveform: np.ndarray, padding=160, chunk_length=None):
"""
Compute the log-Mel spectrogram of the provided audio.
"""
Expand All @@ -84,31 +204,27 @@ def __call__(self, waveform, padding=True, chunk_length=None, to_cpu=False):
self.n_samples = chunk_length * self.sampling_rate
self.nb_max_frames = self.n_samples // self.hop_length

if waveform.dtype is not torch.float32:
waveform = waveform.to(torch.float32)

waveform = (
waveform.to(self.device)
if self.device == "cuda" and not waveform.is_cuda
else waveform
)
if waveform.dtype is not np.float32:
waveform = waveform.astype(np.float32)

if padding:
waveform = torch.nn.functional.pad(waveform, (0, self.n_samples))
waveform = np.pad(waveform, (0, padding))

window = torch.hann_window(self.n_fft).to(waveform.device)
window = np.hanning(self.n_fft + 1)[:-1].astype("float32")

stft = torch.stft(
waveform, self.n_fft, self.hop_length, window=window, return_complex=True
)
magnitudes = stft[..., :-1].abs() ** 2
stft = self.stft(
waveform,
self.n_fft,
self.hop_length,
window=window,
return_complex=True,
).astype("complex64")
magnitudes = np.abs(stft[..., :-1]) ** 2

mel_spec = self.mel_filters.to(waveform.device) @ magnitudes
mel_spec = self.mel_filters @ magnitudes

log_spec = torch.clamp(mel_spec, min=1e-10).log10()
log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
log_spec = np.log10(np.clip(mel_spec, a_min=1e-10, a_max=None))
log_spec = np.maximum(log_spec, log_spec.max() - 8.0)
log_spec = (log_spec + 4.0) / 4.0

# When the model is running on multiple GPUs, the output should be moved
# to the CPU since we don't know which GPU will handle the next job.
return log_spec.cpu() if to_cpu else log_spec
return log_spec
Loading

0 comments on commit 3e0ba86

Please sign in to comment.