From 9185447f14a473e994fbbb43236ee34ba60409d2 Mon Sep 17 00:00:00 2001 From: christhetree Date: Tue, 23 Jul 2024 10:55:13 +0100 Subject: [PATCH] [cm] Decreasing delay --- neutone_sdk/cached_mel_spec.py | 29 +++++++++++++++++------------ 1 file changed, 17 insertions(+), 12 deletions(-) diff --git a/neutone_sdk/cached_mel_spec.py b/neutone_sdk/cached_mel_spec.py index 2057b38..99c55e0 100644 --- a/neutone_sdk/cached_mel_spec.py +++ b/neutone_sdk/cached_mel_spec.py @@ -45,9 +45,12 @@ def __init__( n_mels=n_mels, center=False, ) - self.queue = CircularInplaceTensorQueue(n_ch, self.n_fft, use_debug_mode) - self.register_buffer("padding", tr.zeros((n_ch, self.n_fft))) - self.queue.push(self.padding) + self.padding_n_samples = self.n_fft - self.hop_len + self.cache = CircularInplaceTensorQueue( + n_ch, self.padding_n_samples, use_debug_mode + ) + self.register_buffer("padding", tr.zeros((n_ch, self.padding_n_samples))) + self.cache.push(self.padding) def forward(self, x: Tensor) -> Tensor: if self.use_debug_mode: @@ -61,16 +64,16 @@ def forward(self, x: Tensor) -> Tensor: padded_x = tr.cat([self.padding, x], dim=1) # log.info(f"padded_x = {padded_x}") spec = self.mel_spec(padded_x) - spec = spec[:, :, :n_frames] + spec = spec[:, :, -n_frames:] - lookahead_idx = min(n_samples, self.n_fft) - self.queue.push(x[:, -lookahead_idx:]) - self.queue.fill(self.padding) + padding_idx = min(n_samples, self.padding_n_samples) + self.cache.push(x[:, -padding_idx:]) + self.cache.fill(self.padding) return spec @tr.jit.export def get_delay_samples(self) -> int: - return self.n_fft // 2 + return (self.n_fft // 2) - self.hop_len @tr.jit.export def get_delay_frames(self) -> int: @@ -78,13 +81,13 @@ def get_delay_frames(self) -> int: @tr.jit.export def reset(self) -> None: - self.queue.reset() + self.cache.reset() self.padding.zero_() - self.queue.push(self.padding) + self.cache.push(self.padding) def test_cached_mel_spec(): - tr.set_printoptions(precision=2) + tr.set_printoptions(precision=1) tr.random.manual_seed(0) sr = 44100 @@ -127,7 +130,9 @@ def test_cached_mel_spec(): curr_idx += chunk_size if curr_idx < total_n_samples: chunks.append(audio[:, curr_idx:]) - chunks.append(tr.zeros(n_ch, cached_mel_spec.n_fft)) + chunks.append( + tr.zeros(n_ch, cached_mel_spec.get_delay_samples() + cached_mel_spec.hop_len) + ) spec_chunks = [] for chunk in chunks: