Skip to content

Commit

Permalink
feat: add stftautoencoder
Browse files Browse the repository at this point in the history
  • Loading branch information
flavioschneider committed Oct 25, 2022
1 parent aa83393 commit fc1fb9b
Show file tree
Hide file tree
Showing 3 changed files with 98 additions and 1 deletion.
1 change: 1 addition & 0 deletions audio_diffusion_pytorch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
Encoder1d,
MultiEncoder1d,
Noiser,
STFTAutoEncoder1d,
T5Embedder,
Tanh,
UNet1d,
Expand Down
96 changes: 96 additions & 0 deletions audio_diffusion_pytorch/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -1687,3 +1687,99 @@ def decode(self, latent: Tensor) -> List[Tensor]:
x = self.to_out(x)
channels_list += [x]
return channels_list[::-1]


class STFT(nn.Module):
    """Short-time Fourier transform between waveforms and (magnitude, phase).

    `encode` maps a waveform batch of shape (batch, channels, time) to a
    magnitude tensor and a phase tensor, each of shape
    (batch, channels, frequencies, frames). `decode` inverts that mapping and
    returns a waveform with exactly `length` samples per channel.

    Args:
        length: Number of samples produced by `decode` (forwarded to
            `torch.istft` as `length`).
        num_fft: FFT size (`n_fft`).
        hop_length: Hop between successive frames, in samples.
        window_length: Length of the Hann analysis/synthesis window.
    """

    def __init__(
        self,
        length: int,
        num_fft: int = 1024,
        hop_length: int = 256,
        window_length: int = 1024,
    ):
        super().__init__()
        self.num_fft = num_fft
        self.hop_length = hop_length
        self.window_length = window_length
        self.length = length
        # Registered as a buffer so the window follows the module across
        # `.to(device)` / `.cuda()` calls without being a trainable parameter.
        self.register_buffer("window", torch.hann_window(window_length))

    def encode(self, wave: Tensor) -> Tuple[Tensor, Tensor]:
        """Return the (magnitude, phase) pair of the STFT of `wave`.

        Args:
            wave: Waveform tensor of shape (batch, channels, time).

        Returns:
            Tuple `(magnitude, phase)`, both of shape
            (batch, channels, frequencies, frames). The magnitude is bounded
            below by 1e-4 (square root of the 1e-8 clamp).
        """
        batch, channels, samples = wave.shape
        # torch.stft handles a single leading batch axis, so fold channels in.
        flat_wave = wave.reshape(batch * channels, samples)

        stft = torch.stft(
            flat_wave,
            n_fft=self.num_fft,
            hop_length=self.hop_length,
            win_length=self.window_length,
            window=self.window,  # type: ignore
            return_complex=True,
        )

        # Clamp the squared modulus before the square root so the result (and
        # the gradient of sqrt) stays finite for silent bins.
        squared = (stft.real ** 2) + (stft.imag ** 2)
        mag = torch.sqrt(torch.clamp(squared, min=1e-8))
        phase = torch.angle(stft)

        # Restore the separate batch and channel axes.
        out_shape = (batch, channels) + tuple(stft.shape[1:])
        return mag.reshape(out_shape), phase.reshape(out_shape)

    def decode(self, magnitude: Tensor, phase: Tensor) -> Tensor:
        """Reconstruct a waveform from magnitude and phase spectrograms.

        Args:
            magnitude: Tensor of shape (batch, channels, frequencies, frames).
            phase: Tensor with the same shape as `magnitude`, in radians.

        Returns:
            Waveform tensor of shape (batch, channels, length).

        Raises:
            ValueError: If `magnitude` and `phase` differ in shape.
        """
        if magnitude.shape != phase.shape:
            raise ValueError("magnitude and phase must be same shape")

        batch, channels, freqs, frames = magnitude.shape
        flat_shape = (batch * channels, freqs, frames)
        real = (magnitude * torch.cos(phase)).reshape(flat_shape)
        imag = (magnitude * torch.sin(phase)).reshape(flat_shape)
        # torch.istft expects a complex tensor; the legacy real layout with a
        # trailing (real, imag) axis is rejected by current PyTorch releases.
        stft = torch.complex(real, imag)

        wave = torch.istft(
            stft,
            n_fft=self.num_fft,
            hop_length=self.hop_length,
            win_length=self.window_length,
            window=self.window,  # type: ignore
            length=self.length,
        )
        return wave.reshape(batch, channels, wave.shape[-1])


class STFTAutoEncoder1d(AutoEncoder1d):
    """`AutoEncoder1d` variant that works on STFT features of a waveform.

    The encoder input is the log-magnitude spectrogram with the channel and
    frequency axes merged; the decoder output holds twice as many channels,
    which are split into a log-magnitude half and a phase half and turned
    back into a waveform through the `STFT` helper.

    Args:
        in_channels: Number of waveform channels.
        length: Number of samples of the reconstructed waveform.
        num_fft: FFT size used by the STFT.
        hop_length: Hop between STFT frames, in samples.
        window_length: Length of the STFT window.
        **kwargs: Forwarded unchanged to `AutoEncoder1d.__init__`.
    """

    def __init__(
        self,
        in_channels: int,
        length: int,
        num_fft: int = 1024,
        hop_length: int = 256,
        window_length: int = 1024,
        **kwargs,
    ):
        # Frequency bins per frame produced by the STFT for this FFT size.
        frequency_channels = num_fft // 2 + 1
        self.frequency_channels = frequency_channels

        # Every (waveform channel, frequency bin) pair is one input channel;
        # the output carries two values per pair (log-magnitude and phase).
        spectral_channels = in_channels * frequency_channels
        super().__init__(
            in_channels=spectral_channels,
            out_channels=spectral_channels * 2,
            patch_blocks=1,
            patch_factor=1,
            **kwargs,
        )

        self.stft = STFT(
            length=length,
            num_fft=num_fft,
            hop_length=hop_length,
            window_length=window_length,
        )

    def encode(
        self, wave: Tensor, with_info: bool = False
    ) -> Union[Tensor, Tuple[Tensor, Any]]:
        """Encode a waveform via its log-magnitude spectrogram.

        The phase returned by the STFT is not used by the encoder.
        """
        magnitude, _ = self.stft.encode(wave)
        # Merge channel and frequency axes into the 1d channel axis.
        features = rearrange(torch.log(magnitude), "b c f t -> b (c f) t")
        return super().encode(features, with_info)

    def decode(self, z: Tensor) -> Tensor:
        """Decode a latent into a waveform."""
        decoded = super().decode(z)
        spectral = rearrange(
            decoded,
            "b (c f i) t -> b (c i) f t",
            i=2,
            f=self.frequency_channels,
        )
        # NOTE(review): with the "(c i)" grouping the two chunks below separate
        # along `i` only when c == 1; for c > 1 the halves are split along `c`.
        # The layout is learned, so this is kept as-is - confirm it is intended.
        log_magnitude, phase = torch.chunk(spectral, chunks=2, dim=1)
        magnitude = torch.exp(log_magnitude)
        return self.stft.decode(magnitude=magnitude, phase=phase)
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name="audio-diffusion-pytorch",
packages=find_packages(exclude=[]),
version="0.0.76",
version="0.0.77",
license="MIT",
description="Audio Diffusion - PyTorch",
long_description_content_type="text/markdown",
Expand Down

0 comments on commit fc1fb9b

Please sign in to comment.