Skip to content

Commit

Permalink
[cm] Decreasing delay
Browse files Browse the repository at this point in the history
  • Loading branch information
christhetree committed Jul 23, 2024
1 parent 57428fa commit 9185447
Showing 1 changed file with 17 additions and 12 deletions.
29 changes: 17 additions & 12 deletions neutone_sdk/cached_mel_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,12 @@ def __init__(
n_mels=n_mels,
center=False,
)
self.queue = CircularInplaceTensorQueue(n_ch, self.n_fft, use_debug_mode)
self.register_buffer("padding", tr.zeros((n_ch, self.n_fft)))
self.queue.push(self.padding)
self.padding_n_samples = self.n_fft - self.hop_len
self.cache = CircularInplaceTensorQueue(
n_ch, self.padding_n_samples, use_debug_mode
)
self.register_buffer("padding", tr.zeros((n_ch, self.padding_n_samples)))
self.cache.push(self.padding)

def forward(self, x: Tensor) -> Tensor:
if self.use_debug_mode:
Expand All @@ -61,30 +64,30 @@ def forward(self, x: Tensor) -> Tensor:
padded_x = tr.cat([self.padding, x], dim=1)
# log.info(f"padded_x = {padded_x}")
spec = self.mel_spec(padded_x)
spec = spec[:, :, :n_frames]
spec = spec[:, :, -n_frames:]

lookahead_idx = min(n_samples, self.n_fft)
self.queue.push(x[:, -lookahead_idx:])
self.queue.fill(self.padding)
padding_idx = min(n_samples, self.padding_n_samples)
self.cache.push(x[:, -padding_idx:])
self.cache.fill(self.padding)
return spec

@tr.jit.export
def get_delay_samples(self) -> int:
return self.n_fft // 2
return (self.n_fft // 2) - self.hop_len

@tr.jit.export
def get_delay_frames(self) -> int:
return self.get_delay_samples() // self.hop_len

@tr.jit.export
def reset(self) -> None:
self.queue.reset()
self.cache.reset()
self.padding.zero_()
self.queue.push(self.padding)
self.cache.push(self.padding)


def test_cached_mel_spec():
tr.set_printoptions(precision=2)
tr.set_printoptions(precision=1)
tr.random.manual_seed(0)

sr = 44100
Expand Down Expand Up @@ -127,7 +130,9 @@ def test_cached_mel_spec():
curr_idx += chunk_size
if curr_idx < total_n_samples:
chunks.append(audio[:, curr_idx:])
chunks.append(tr.zeros(n_ch, cached_mel_spec.n_fft))
chunks.append(
tr.zeros(n_ch, cached_mel_spec.get_delay_samples() + cached_mel_spec.hop_len)
)

spec_chunks = []
for chunk in chunks:
Expand Down

0 comments on commit 9185447

Please sign in to comment.