Skip to content

Commit

Permalink
feat: encoder1d, decoder1d, autoencoder bottleneck channels
Browse files Browse the repository at this point in the history
  • Loading branch information
flavioschneider committed Oct 23, 2022
1 parent d48ac1a commit 24ff00f
Show file tree
Hide file tree
Showing 3 changed files with 127 additions and 32 deletions.
2 changes: 2 additions & 0 deletions audio_diffusion_pytorch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@
)
from .modules import (
AutoEncoder1d,
Decoder1d,
Encoder1d,
MultiEncoder1d,
Noiser,
T5Embedder,
Expand Down
155 changes: 124 additions & 31 deletions audio_diffusion_pytorch/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -1364,7 +1364,7 @@ def forward(
return (x, info) if with_info else x


class AutoEncoder1d(nn.Module):
class Encoder1d(nn.Module):
def __init__(
self,
in_channels: int,
Expand All @@ -1375,16 +1375,10 @@ def __init__(
multipliers: Sequence[int],
factors: Sequence[int],
num_blocks: Sequence[int],
use_noisy: bool = False,
bottleneck: Union[Bottleneck, List[Bottleneck]] = [],
use_magnitude_channels: bool = False,
out_channels: Optional[int] = None,
):
super().__init__()
num_layers = len(multipliers) - 1
self.bottlenecks = nn.ModuleList(to_list(bottleneck))
self.use_noisy = use_noisy
self.use_magnitude_channels = use_magnitude_channels

assert len(factors) >= num_layers and len(num_blocks) >= num_layers

self.to_in = Patcher(
Expand All @@ -1408,10 +1402,66 @@ def __init__(
]
)

self.to_out = (
nn.Conv1d(
in_channels=channels * multipliers[-1],
out_channels=out_channels,
kernel_size=1,
)
if exists(out_channels)
else nn.Identity()
)

def forward(
    self, x: Tensor, with_info: bool = False
) -> Union[Tensor, Tuple[Tensor, Any]]:
    """Encode ``x`` and optionally expose the intermediate activations.

    The input is passed through ``self.to_in``, then through every module
    in ``self.downsamples`` in order, and finally through ``self.to_out``.

    Args:
        x: Input passed to ``self.to_in``.
        with_info: When true, also return a dict holding the output of each
            downsampling stage.

    Returns:
        The encoded value, or ``(encoded, info)`` when ``with_info`` is
        true, where ``info["xs"]`` lists the per-stage outputs in the order
        they were produced.
    """
    hidden = self.to_in(x)

    # Record each stage's output so callers can inspect the hierarchy.
    stage_outputs = []
    for stage in self.downsamples:
        hidden = stage(hidden)
        stage_outputs.append(hidden)

    encoded = self.to_out(hidden)

    if not with_info:
        return encoded
    return encoded, dict(xs=stage_outputs)


class Decoder1d(nn.Module):
def __init__(
self,
out_channels: int,
channels: int,
patch_blocks: int,
patch_factor: int,
resnet_groups: int,
multipliers: Sequence[int],
factors: Sequence[int],
num_blocks: Sequence[int],
use_magnitude_channels: bool = False,
in_channels: Optional[int] = None,
):
super().__init__()
num_layers = len(multipliers) - 1
self.use_magnitude_channels = use_magnitude_channels

assert len(factors) >= num_layers and len(num_blocks) >= num_layers

self.to_in = (
Conv1d(
in_channels=in_channels,
out_channels=channels * multipliers[-1],
kernel_size=1,
)
if exists(in_channels)
else nn.Identity()
)

self.upsamples = nn.ModuleList(
[
UpsampleBlock1d(
in_channels=channels * multipliers[i + 1] * (use_noisy + 1),
in_channels=channels * multipliers[i + 1],
out_channels=channels * multipliers[i],
factor=factors[i],
num_groups=resnet_groups,
Expand All @@ -1424,12 +1474,73 @@ def __init__(
)

self.to_out = Unpatcher(
in_channels=channels * (use_noisy + 1),
out_channels=in_channels * (2 if use_magnitude_channels else 1),
in_channels=channels,
out_channels=out_channels * (2 if use_magnitude_channels else 1),
blocks=patch_blocks,
factor=patch_factor,
)

def forward(self, x: Tensor) -> Tensor:
    """Decode ``x`` back through the upsampling stack.

    The input is passed through ``self.to_in``, then through every module
    in ``self.upsamples`` in order, and finally through ``self.to_out``.
    When ``self.use_magnitude_channels`` is set, the result is additionally
    passed through ``merge_magnitude_channels``.

    Unlike the encoder, this method has no ``with_info`` mode and always
    returns a single value, so the return annotation is plain ``Tensor``.

    Args:
        x: Input passed to ``self.to_in``.

    Returns:
        The decoded value.
    """
    x = self.to_in(x)

    for upsample in self.upsamples:
        x = upsample(x)

    x = self.to_out(x)

    if self.use_magnitude_channels:
        # ``to_out`` is built with twice ``out_channels`` in this mode;
        # presumably this helper folds them back -- confirm against
        # merge_magnitude_channels.
        x = merge_magnitude_channels(x)

    return x


class AutoEncoder1d(nn.Module):
def __init__(
    self,
    in_channels: int,
    channels: int,
    patch_blocks: int,
    patch_factor: int,
    resnet_groups: int,
    multipliers: Sequence[int],
    factors: Sequence[int],
    num_blocks: Sequence[int],
    use_noisy: bool = False,
    bottleneck: Union[Bottleneck, List[Bottleneck]] = [],
    bottleneck_channels: Optional[int] = None,
    use_magnitude_channels: bool = False,
):
    """Build an ``Encoder1d`` / ``Decoder1d`` pair around optional bottlenecks.

    Args:
        in_channels: Channel count given to the encoder as ``in_channels``
            and to the decoder as ``out_channels``.
        channels: Forwarded unchanged to both encoder and decoder.
        patch_blocks: Forwarded unchanged to both encoder and decoder.
        patch_factor: Forwarded unchanged to both encoder and decoder.
        resnet_groups: Forwarded unchanged to both encoder and decoder.
        multipliers: Forwarded to both halves; ``len(multipliers) - 1``
            is the number of layers checked against ``factors`` and
            ``num_blocks``.
        factors: Per-layer factors; needs at least one entry per layer.
        num_blocks: Per-layer block counts; needs at least one entry per
            layer.
        use_noisy: Kept only for backward compatibility. Nothing in this
            class reads it any more; a warning is emitted when it is true
            so the option is not ignored silently.
        bottleneck: A single bottleneck or a list of them, stored in
            ``self.bottlenecks``.
        bottleneck_channels: Passed to the encoder as ``out_channels`` and
            to the decoder as ``in_channels``; ``None`` is forwarded as-is.
        use_magnitude_channels: Forwarded to the decoder.

    Raises:
        AssertionError: If ``factors`` or ``num_blocks`` has fewer entries
            than there are layers.
    """
    super().__init__()

    if use_noisy:
        # Function-scope import keeps this block self-contained.
        import warnings

        warnings.warn(
            "AutoEncoder1d: `use_noisy` is accepted for backward "
            "compatibility but currently has no effect.",
            stacklevel=2,
        )

    num_layers = len(multipliers) - 1
    # NOTE(review): the shared default list looks read-only here -- confirm
    # that `to_list` does not mutate its argument.
    self.bottlenecks = nn.ModuleList(to_list(bottleneck))

    assert len(factors) >= num_layers and len(num_blocks) >= num_layers

    self.encoder = Encoder1d(
        in_channels=in_channels,
        channels=channels,
        patch_blocks=patch_blocks,
        patch_factor=patch_factor,
        resnet_groups=resnet_groups,
        multipliers=multipliers,
        factors=factors,
        num_blocks=num_blocks,
        out_channels=bottleneck_channels,
    )

    self.decoder = Decoder1d(
        in_channels=bottleneck_channels,
        out_channels=in_channels,
        channels=channels,
        patch_blocks=patch_blocks,
        patch_factor=patch_factor,
        resnet_groups=resnet_groups,
        multipliers=multipliers,
        factors=factors,
        num_blocks=num_blocks,
        use_magnitude_channels=use_magnitude_channels,
    )

def forward(
self, x: Tensor, with_info: bool = False
) -> Union[Tensor, Tuple[Tensor, Any]]:
Expand All @@ -1440,12 +1551,7 @@ def forward(
def encode(
self, x: Tensor, with_info: bool = False
) -> Union[Tensor, Tuple[Tensor, Any]]:
xs = []
x = self.to_in(x)
for downsample in self.downsamples:
x = downsample(x)
xs += [x]
info = dict(xs=xs)
x, info = self.encoder(x, with_info=True)

for bottleneck in self.bottlenecks:
x, info_bottleneck = bottleneck(x, with_info=True)
Expand All @@ -1454,20 +1560,7 @@ def encode(
return (x, info) if with_info else x

def decode(self, x: Tensor) -> Tensor:
for upsample in self.upsamples:
if self.use_noisy:
x = torch.cat([x, torch.randn_like(x)], dim=1)
x = upsample(x)

if self.use_noisy:
x = torch.cat([x, torch.randn_like(x)], dim=1)

x = self.to_out(x)

if self.use_magnitude_channels:
x = merge_magnitude_channels(x)

return x
return self.decoder(x)


class MultiEncoder1d(nn.Module):
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name="audio-diffusion-pytorch",
packages=find_packages(exclude=[]),
version="0.0.73",
version="0.0.74",
license="MIT",
description="Audio Diffusion - PyTorch",
long_description_content_type="text/markdown",
Expand Down

0 comments on commit 24ff00f

Please sign in to comment.