Skip to content

Commit

Permalink
feat: use inverse window to get perfect reconstruction, real output w…
Browse files Browse the repository at this point in the history
…ith irfft
  • Loading branch information
flavioschneider committed Dec 8, 2022
1 parent fe6f139 commit e6dad63
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 42 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ y = transform.decode(z) # [1, 1, 262144]

## TODO
* [x] Power of 2 length (with `power_of_2_length` constructor arg).
* [ ] Understand why/if inverse window is necessary.
* [x] Understand why/if inverse window is necessary (it is necessary for perfect inversion).
* [ ] Allow variable audio lengths by chunking.

## Appreciation
Expand Down
85 changes: 45 additions & 40 deletions cqt_pytorch/cqt.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@

import torch
import torch.nn.functional as F
from torch import Tensor, nn
from torch import Tensor, einsum, nn


def next_power_of_2(x: Tensor) -> int:
return 2 ** ceil(x.item()).bit_length()
def next_power_of_2(x: float) -> int:
return 2 ** ceil(x).bit_length()


def get_center_frequencies(
Expand All @@ -22,7 +22,7 @@ def get_center_frequencies(
[
frequencies,
torch.tensor([frequency_nyquist]),
# sample_rate - torch.flip(frequencies, dims=[0]) # not necessary
sample_rate - torch.flip(frequencies, dims=[0]), # not necessary
],
dim=0,
)
Expand All @@ -37,13 +37,9 @@ def get_bandwidths(
) -> Tensor: # Omega_k for k in [1, 2*K+1]
"""Compute bandwidths tensor from center frequencies"""
num_bins = num_octaves * num_bins_per_octave # K
q_factor = 1.0 / (
2 ** (1.0 / num_bins_per_octave) - 2 ** (-1.0 / num_bins_per_octave)
)
bandwidths = frequencies[1 : num_bins + 1] / q_factor
bandwidths_symmetric = (
torch.flip(frequencies[1 : num_bins + 1], dims=[0]) / q_factor
)
q = 1.0 / (2 ** (1.0 / num_bins_per_octave) - 2 ** (-1.0 / num_bins_per_octave))
bandwidths = frequencies[1 : num_bins + 1] / q
bandwidths_symmetric = torch.flip(frequencies[1 : num_bins + 1], dims=[0]) / q
bandwidths_all = torch.cat(
[
bandwidths,
Expand All @@ -55,23 +51,19 @@ def get_bandwidths(
return bandwidths_all


def get_windows_range_indices(
lengths: Tensor, positions: Tensor, power_of_2_length: bool
) -> Tensor:
def get_windows_range_indices(positions: Tensor, max_length: int) -> Tensor:
"""Compute windowing tensor of indices"""
num_bins = lengths.shape[0] // 2
max_length = next_power_of_2(lengths.max()) if power_of_2_length else lengths.max()
num_bins = positions.shape[0] // 2
ranges = []
for i in range(num_bins):
start = positions[i] - max_length
ranges += [torch.arange(start=start, end=start + max_length)] # type: ignore
start = positions[i] - max_length // 2
ranges += [torch.arange(start=start, end=start + max_length)] # type: ignore # noqa
return torch.stack(ranges, dim=0).long()


def get_windows(lengths: Tensor, power_of_2_length: bool) -> Tensor:
def get_windows(lengths: Tensor, max_length: int) -> Tensor:
"""Compute tensor of stacked (centered) windows"""
num_bins = lengths.shape[0] // 2
max_length = next_power_of_2(lengths.max()) if power_of_2_length else lengths.max()
windows = []
for length in lengths[:num_bins]:
# Pad windows left and right to center them
Expand All @@ -81,9 +73,16 @@ def get_windows(lengths: Tensor, power_of_2_length: bool) -> Tensor:
return torch.stack(windows, dim=0)


def get_windows_inverse(windows: Tensor, lengths: Tensor) -> Tensor:
num_bins = windows.shape[0]
return torch.einsum("k m, k -> k m", windows**2, lengths[:num_bins])
def get_windows_inverse(
windows: Tensor, windows_range_indices: Tensor, max_length: int, block_length: int
) -> Tensor:
"""Compute tensor of stacked (centered) inverse windows"""
windows_overlap = torch.zeros(block_length).scatter_add_(
dim=0,
index=windows_range_indices.view(-1),
src=(windows**2).view(-1),
)
return windows / (windows_overlap[windows_range_indices] + 1e-8)


class CQT(nn.Module):
Expand Down Expand Up @@ -112,42 +111,48 @@ def __init__(
)

window_lengths = torch.round(bandwidths * block_length / sample_rate)
max_window_length = int(window_lengths.max())

if power_of_2_length:
max_window_length = next_power_of_2(max_window_length)

self.register_buffer(
"windows_range_indices",
windows_range_indices = (
get_windows_range_indices(
lengths=window_lengths,
max_length=max_window_length,
positions=torch.round(frequencies * block_length / sample_rate),
power_of_2_length=power_of_2_length,
),
)
% block_length
)

self.register_buffer(
"windows",
get_windows(lengths=window_lengths, power_of_2_length=power_of_2_length),
)
windows = get_windows(lengths=window_lengths, max_length=max_window_length)

self.register_buffer(
"windows_inverse",
get_windows_inverse(windows=self.windows, lengths=window_lengths), # type: ignore # noqa
windows_inverse = get_windows_inverse(
windows=windows,
windows_range_indices=windows_range_indices,
max_length=max_window_length,
block_length=block_length,
)

self.register_buffer("windows_range_indices", windows_range_indices)
self.register_buffer("windows", windows)
self.register_buffer("windows_inverse", windows_inverse)

def encode(self, waveform: Tensor) -> Tensor:
frequencies = torch.fft.fft(waveform)
crops = frequencies[:, :, self.windows_range_indices]
crops_windowed = torch.einsum("... t k, t k -> ... t k", crops, self.windows)
crops_windowed = einsum("... t k, t k -> ... t k", crops, self.windows)
transform = torch.fft.ifft(crops_windowed)
return transform

def decode(self, transform: Tensor) -> Tensor:
b, c, length = *transform.shape[0:2], self.block_length
crops_windowed = torch.fft.fft(transform)
crops_unwindowed = crops_windowed # TODO crops_unwindowed = torch.einsum('... t k, t k -> ... t k', transformed, self.windows_inverse) # noqa
crops = einsum("... t k, t k -> ... t k", crops_windowed, self.windows_inverse)
frequencies = torch.zeros(b, c, length).to(transform)
frequencies.scatter_add_(
dim=-1,
index=self.windows_range_indices.view(-1).expand(b, c, -1) % length, # type: ignore # noqa
src=crops_unwindowed.view(b, c, -1),
index=self.windows_range_indices.view(-1).expand(b, c, -1), # type: ignore
src=crops.view(b, c, -1),
)
waveform = torch.fft.ifft(frequencies)
waveform = torch.fft.irfft(frequencies, n=length)
return waveform
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name="cqt-pytorch",
packages=find_packages(exclude=[]),
version="0.0.3",
version="0.0.4",
license="MIT",
description="CQT Pytorch",
long_description_content_type="text/markdown",
Expand Down

0 comments on commit e6dad63

Please sign in to comment.