Skip to content

Commit

Permalink
[cm] Fixing bug
Browse files Browse the repository at this point in the history
  • Loading branch information
christhetree committed Jul 22, 2024
1 parent e712221 commit 57428fa
Showing 1 changed file with 15 additions and 11 deletions.
26 changes: 15 additions & 11 deletions neutone_sdk/cached_mel_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,12 +58,13 @@ def forward(self, x: Tensor) -> Tensor:
), "input audio n_samples must be divisible by hop_len"
n_samples = x.size(1)
n_frames = n_samples // self.hop_len
lookahead_idx = min(n_samples, self.n_fft)
padded_x = tr.cat([self.padding, x], dim=1)
# log.info(f"padded_x = {padded_x}")
spec = self.mel_spec(padded_x)
spec = spec[:, :, :n_frames]
self.queue.push(x[:, :lookahead_idx])

lookahead_idx = min(n_samples, self.n_fft)
self.queue.push(x[:, -lookahead_idx:])
self.queue.fill(self.padding)
return spec

Expand All @@ -83,14 +84,15 @@ def reset(self) -> None:


def test_cached_mel_spec():
tr.set_printoptions(precision=2)
tr.random.manual_seed(0)

sr = 44100
n_ch = 1
n_fft = 2048
hop_len = 128
n_mels = 1
total_n_samples = 100 * hop_len
n_mels = 16
total_n_samples = 1000 * hop_len

audio = tr.rand(n_ch, total_n_samples)
# log.info(f"audio = {audio}")
Expand All @@ -108,12 +110,14 @@ def test_cached_mel_spec():
delay_frames = cached_mel_spec.get_delay_frames()
cached_spec = cached_mel_spec(audio)
cached_spec = cached_spec[:, :, delay_frames:]
assert tr.allclose(spec[:, :, :cached_spec.size(2)], cached_spec)
# log.info(f" spec = {spec}")
# log.info(f"cached_spec = {cached_spec}")
assert tr.allclose(spec[:, :, : cached_spec.size(2)], cached_spec)
cached_mel_spec.reset()

chunks = []
min_chunk_size = 1
max_chunk_size = 10
max_chunk_size = 100
curr_idx = 0
while curr_idx < total_n_samples - max_chunk_size:
chunk_size = (
Expand All @@ -129,11 +133,11 @@ def test_cached_mel_spec():
for chunk in chunks:
spec_chunk = cached_mel_spec(chunk)
spec_chunks.append(spec_chunk)
chunked_mel_spec = tr.cat(spec_chunks, dim=2)
chunked_mel_spec = chunked_mel_spec[:, :, delay_frames:]
# log.info(f" spec = {spec}")
# log.info(f"chunked_mel_spec = {chunked_mel_spec}")
assert tr.allclose(spec, chunked_mel_spec[:, :, : spec.size(2)])
chunked_spec = tr.cat(spec_chunks, dim=2)
chunked_spec = chunked_spec[:, :, delay_frames:]
# log.info(f" spec = {spec}")
# log.info(f"chunked_spec = {chunked_spec}")
assert tr.allclose(spec, chunked_spec[:, :, : spec.size(2)])


if __name__ == "__main__":
Expand Down

0 comments on commit 57428fa

Please sign in to comment.