Skip to content

Commit

Permalink
feat: back to previous vocoder
Browse files Browse the repository at this point in the history
  • Loading branch information
flavioschneider committed Nov 1, 2022
1 parent d0b206a commit 015b152
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 29 deletions.
65 changes: 38 additions & 27 deletions audio_diffusion_pytorch/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
)
from .modules import STFT, SinusoidalEmbedding, UNet1d, UNetConditional1d
from .utils import (
closest_power_2,
default,
downsample,
exists,
Expand Down Expand Up @@ -204,7 +203,7 @@ def encode(
) -> Union[Tensor, Tuple[Tensor, Any]]:
latent, info = self.encoder(x, with_info=True)
for bottleneck in self.bottlenecks:
x, info_bottleneck = bottleneck(x, with_info=True)
latent, info_bottleneck = bottleneck(latent, with_info=True)
info = {**info, **prefix_dict("bottleneck_", info_bottleneck)}
return (latent, info) if with_info else latent

Expand All @@ -226,31 +225,43 @@ def decode(self, latent: Tensor, **kwargs) -> Tensor:


class DiffusionVocoder1d(Model1d):
def __init__(self, in_channels: int, stft_num_fft: int, **kwargs):
self.stft_num_fft = stft_num_fft
spectrogram_channels = stft_num_fft // 2 + 1
def __init__(
self,
in_channels: int,
stft_num_fft: int,
**kwargs,
):
self.frequency_channels = stft_num_fft // 2 + 1
spectrogram_channels = in_channels * self.frequency_channels

stft_kwargs, kwargs = groupby("stft_", kwargs)
default_kwargs = dict(
in_channels=in_channels,
use_stft=True,
stft_num_fft=stft_num_fft,
context_channels=[in_channels * spectrogram_channels],
in_channels=spectrogram_channels, context_channels=[spectrogram_channels]
)

super().__init__(**{**default_kwargs, **kwargs}) # type: ignore
self.stft = STFT(num_fft=stft_num_fft, **stft_kwargs)

def forward(self, x: Tensor, **kwargs) -> Tensor:
# Get magnitude spectrogram from true wave
magnitude, _ = self.unet.stft.encode(x)
magnitude = rearrange(magnitude, "b c f t -> b (c f) t")
# Get diffusion loss while conditioning on magnitude
return self.diffusion(x, channels_list=[magnitude], **kwargs)
def forward_wave(self, x: Tensor, **kwargs) -> Tensor:
# Get magnitude and phase of true wave
magnitude, phase = self.stft.encode(x)
return self(magnitude, phase, **kwargs)

def sample(self, spectrogram: Tensor, **kwargs): # type: ignore
b, c, _, t, device = *spectrogram.shape, spectrogram.device
magnitude = rearrange(spectrogram, "b c f t -> b (c f) t")
timesteps = closest_power_2(self.unet.stft.hop_length * t)
noise = torch.randn((b, c, timesteps), device=device)
default_kwargs = dict(channels_list=[magnitude])
return super().sample(noise, **{**default_kwargs, **kwargs}) # type: ignore # noqa
def forward(self, magnitude: Tensor, phase: Tensor, **kwargs) -> Tensor: # type: ignore # noqa
magnitude = rearrange(magnitude, "b c f t -> b (c f) t")
phase = rearrange(phase, "b c f t -> b (c f) t")
# Get diffusion phase loss while conditioning on magnitude (/pi [-1,1] range)
return self.diffusion(phase / pi, channels_list=[magnitude], **kwargs)

def sample(self, magnitude: Tensor, **kwargs): # type: ignore
b, c, f, t, device = *magnitude.shape, magnitude.device
magnitude_flat = rearrange(magnitude, "b c f t -> b (c f) t")
noise = torch.randn((b, c * f, t), device=device)
default_kwargs = dict(channels_list=[magnitude_flat])
phase_flat = super().sample(noise, **{**default_kwargs, **kwargs}) # type: ignore # noqa
phase = rearrange(phase_flat, "b (c f) t -> b c f t", c=c)
wave = self.stft.decode(magnitude, phase * pi)
return wave


class DiffusionUpphaser1d(DiffusionUpsampler1d):
Expand Down Expand Up @@ -371,13 +382,13 @@ def __init__(self, in_channels: int, **kwargs):
in_channels=in_channels,
stft_num_fft=1023,
stft_hop_length=256,
channels=64,
channels=512,
patch_blocks=1,
patch_factor=1,
multipliers=[48, 32, 16, 8, 8, 8, 8],
factors=[2, 2, 2, 1, 1, 1],
num_blocks=[1, 1, 1, 1, 1, 1],
attentions=[0, 0, 0, 1, 1, 1],
multipliers=[3, 2, 1, 1, 1, 1, 1, 1],
factors=[1, 2, 2, 2, 2, 2, 2],
num_blocks=[1, 1, 1, 1, 1, 1, 1],
attentions=[0, 0, 0, 0, 1, 1, 1],
attention_heads=8,
attention_features=64,
attention_multiplier=2,
Expand Down
2 changes: 1 addition & 1 deletion audio_diffusion_pytorch/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -1096,7 +1096,7 @@ def get_channels(
assert exists(channels), message
# Check channels
num_channels = self.context_channels[layer]
message = f"Expected context with {channels} channels at index {channels_id}"
message = f"Expected context with {num_channels} channels at idx {channels_id}"
assert channels.shape[1] == num_channels, message
# STFT channels if requested
channels = self.stft.encode1d(channels) if self.use_stft_context else channels # type: ignore # noqa
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name="audio-diffusion-pytorch",
packages=find_packages(exclude=[]),
version="0.0.81",
version="0.0.82",
license="MIT",
description="Audio Diffusion - PyTorch",
long_description_content_type="text/markdown",
Expand Down

0 comments on commit 015b152

Please sign in to comment.