Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[cm] Cached Mel spec implemetation #82

Merged
merged 6 commits into from
Oct 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
132 changes: 132 additions & 0 deletions neutone_sdk/cached_mel_spec.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
import logging
import os
from typing import Optional, Callable

import torch as tr
from torch import Tensor
from torch import nn
from torchaudio.transforms import MelSpectrogram

from neutone_sdk import CircularInplaceTensorQueue

logging.basicConfig()
log = logging.getLogger(__name__)
log.setLevel(level=os.environ.get("LOGLEVEL", "INFO"))


class CachedMelSpec(nn.Module):
def __init__(
self,
sr: int,
n_ch: int,
n_fft: int = 2048,
hop_len: int = 512,
f_min: float = 0.0,
f_max: Optional[float] = None,
n_mels: int = 128,
window_fn: Callable[..., Tensor] = tr.hann_window,
power: float = 2.0,
normalized: bool = False,
center: bool = True,
use_debug_mode: bool = True,
) -> None:
"""
Creates a Mel spectrogram that supports streaming of a centered, non-causal
Mel spectrogram operation that uses zero padding. Using this will result in
audio being delayed by (n_fft / 2) - hop_len samples. When calling forward,
the input audio block length must be a multiple of the hop length.

Parameters:
sr (int): Sample rate of the audio
n_ch (int): Number of audio channels
n_fft (int): STFT n_fft (must be even)
hop_len (int): STFT hop length (must divide into n_fft // 2)
f_min (float): Minimum frequency of the Mel filterbank
f_max (float): Maximum frequency of the Mel filterbank
n_mels (int): Number of mel filterbank bins
window_fn (Callable[..., Tensor]): A function to create a window tensor
power (float): Exponent for the magnitude spectrogram (must be > 0)
normalized (bool): Whether to normalize the mel spectrogram or not
center (bool): Whether to center the mel spectrogram (must be True)
use_debug_mode (bool): Whether to use debug mode or not
"""
super().__init__()
assert center, "center must be True, causal mode is not supported yet"
assert n_fft % 2 == 0, "n_fft must be even"
assert (n_fft // 2) % hop_len == 0, "n_fft // 2 must be divisible by hop_len"
self.n_ch = n_ch
self.n_fft = n_fft
self.hop_len = hop_len
self.use_debug_mode = use_debug_mode
self.mel_spec = MelSpectrogram(
sample_rate=sr,
n_fft=n_fft,
hop_length=hop_len,
f_min=f_min,
f_max=f_max,
n_mels=n_mels,
window_fn=window_fn,
power=power,
normalized=normalized,
center=False, # We use a causal STFT since we do the padding ourselves
)
self.padding_n_samples = self.n_fft - self.hop_len
self.cache = CircularInplaceTensorQueue(
n_ch, self.padding_n_samples, use_debug_mode
)
self.register_buffer("padding", tr.zeros((n_ch, self.padding_n_samples)))
self.cache.push(self.padding)

def forward(self, x: Tensor) -> Tensor:
"""
Computes the Mel spectrogram of the input audio tensor. Supports streaming as
long as the input audio tensor is a multiple of the hop length.
"""
if self.use_debug_mode:
assert x.ndim == 2, "input audio must have shape (n_ch, n_samples)"
assert x.size(0) == self.n_ch, "input audio n_ch is incorrect"
assert (
x.size(1) % self.hop_len == 0
), "input audio n_samples must be divisible by hop_len"
# Compute the Mel spec
n_samples = x.size(1)
n_frames = n_samples // self.hop_len
padded_x = tr.cat([self.padding, x], dim=1)
padded_spec = self.mel_spec(padded_x)
spec = padded_spec[:, :, -n_frames:]

# Update the cache and padding
padding_idx = min(n_samples, self.padding_n_samples)
self.cache.push(x[:, -padding_idx:])
self.cache.fill(self.padding)
return spec

def prepare_for_inference(self) -> None:
"""
Prepares the cached Mel spectrogram for inference by disabling debug mode.
"""
self.cache.use_debug_mode = False
self.use_debug_mode = False

@tr.jit.export
def get_delay_samples(self) -> int:
"""
Returns the number of samples of delay of the cached Mel spectrogram.
"""
return (self.n_fft // 2) - self.hop_len

@tr.jit.export
def get_delay_frames(self) -> int:
"""
Returns the number of frames of delay of the cached Mel spectrogram.
"""
return self.get_delay_samples() // self.hop_len

@tr.jit.export
def reset(self) -> None:
"""
Resets the cache and padding of the cached Mel spectrogram.
"""
self.cache.reset()
self.padding.zero_()
self.cache.push(self.padding)
84 changes: 84 additions & 0 deletions testing/test_cached_mel_spec.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
import logging
import os

import torch as tr
from torchaudio.transforms import MelSpectrogram

from neutone_sdk.cached_mel_spec import CachedMelSpec

logging.basicConfig()
log = logging.getLogger(__name__)
log.setLevel(level=os.environ.get("LOGLEVEL", "INFO"))


def test_cached_mel_spec():
# Setup
tr.set_printoptions(precision=1)
tr.random.manual_seed(42)

sr = 44100
n_ch = 1
n_fft = 2048
hop_len = 128
n_mels = 16
total_n_samples = 1000 * hop_len

audio = tr.rand(n_ch, total_n_samples)
# log.info(f"audio = {audio}")
mel_spec = MelSpectrogram(
sample_rate=sr,
n_fft=n_fft,
hop_length=hop_len,
n_mels=n_mels,
center=True,
pad_mode="constant",
)
cached_mel_spec = CachedMelSpec(
sr=sr, n_ch=n_ch, n_fft=n_fft, hop_len=hop_len, n_mels=n_mels
)

# Test delay
delay_samples = cached_mel_spec.get_delay_samples()
assert delay_samples == n_fft // 2 - hop_len

# Test processing all audio at once
spec = mel_spec(audio)
delay_frames = cached_mel_spec.get_delay_frames()
cached_spec = cached_mel_spec(audio)
cached_spec = cached_spec[:, :, delay_frames:]
# log.info(f" spec = {spec}")
# log.info(f"cached_spec = {cached_spec}")
assert tr.allclose(spec[:, :, : cached_spec.size(2)], cached_spec)
cached_mel_spec.reset()

# Test processing audio in chunks (random chunk size)
chunks = []
min_chunk_size = 1
max_chunk_size = 100
curr_idx = 0
while curr_idx < total_n_samples - max_chunk_size:
chunk_size = (
tr.randint(min_chunk_size, max_chunk_size + 1, (1,)).item() * hop_len
)
chunks.append(audio[:, curr_idx : curr_idx + chunk_size])
curr_idx += chunk_size
if curr_idx < total_n_samples:
chunks.append(audio[:, curr_idx:])
chunks.append(
tr.zeros(n_ch, cached_mel_spec.get_delay_samples() + cached_mel_spec.hop_len)
)

spec_chunks = []
for chunk in chunks:
spec_chunk = cached_mel_spec(chunk)
spec_chunks.append(spec_chunk)
chunked_spec = tr.cat(spec_chunks, dim=2)
chunked_spec = chunked_spec[:, :, delay_frames:]
# log.info(f" spec = {spec}")
# log.info(f"chunked_spec = {chunked_spec}")
assert tr.allclose(spec, chunked_spec)
log.info("test_cached_mel_spec passed!")


if __name__ == "__main__":
test_cached_mel_spec()