Skip to content

Commit

Permalink
feat: add tanh bottleneck, option to use multiple bottlenecks
Browse files Browse the repository at this point in the history
  • Loading branch information
flavioschneider committed Oct 23, 2022
1 parent e22b4e9 commit cfe358f
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 7 deletions.
1 change: 1 addition & 0 deletions audio_diffusion_pytorch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
AutoEncoder1d,
MultiEncoder1d,
T5Embedder,
Tanh,
UNet1d,
UNetConditional1d,
Variational,
Expand Down
21 changes: 15 additions & 6 deletions audio_diffusion_pytorch/modules.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import math
from math import pi
from typing import Any, List, Optional, Sequence, Tuple, Union
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union

import torch
import torch.nn as nn
Expand All @@ -9,7 +9,7 @@
from einops_exts import rearrange_many
from torch import Tensor, einsum

from .utils import default, exists, prod
from .utils import default, exists, prod, to_list

"""
Utils
Expand Down Expand Up @@ -1341,6 +1341,15 @@ def forward(
return (out, dict(loss=loss, mean=mean, logvar=logvar)) if with_info else out


class Tanh(Bottleneck):
    """Bottleneck that passes its input through an element-wise tanh."""

    def forward(
        self, x: Tensor, with_info: bool = False
    ) -> Union[Tensor, Tuple[Tensor, Any]]:
        """Apply ``torch.tanh`` to ``x``.

        Args:
            x: Input tensor.
            with_info: When true, also return an info dictionary, which is
                always empty for this bottleneck.

        Returns:
            The transformed tensor, or a ``(tensor, info)`` tuple when
            ``with_info`` is true.
        """
        squashed = torch.tanh(x)
        if not with_info:
            return squashed
        # This bottleneck produces no auxiliary values, so the info is empty.
        extras: Dict = {}
        return squashed, extras


class AutoEncoder1d(nn.Module):
def __init__(
self,
Expand All @@ -1353,12 +1362,12 @@ def __init__(
factors: Sequence[int],
num_blocks: Sequence[int],
use_noisy: bool = False,
bottleneck: Optional[Bottleneck] = None,
bottleneck: Union[Bottleneck, List[Bottleneck]] = [],
use_magnitude_channels: bool = False,
):
super().__init__()
num_layers = len(multipliers) - 1
self.bottleneck = bottleneck
self.bottlenecks = to_list(bottleneck)
self.use_noisy = use_noisy
self.use_magnitude_channels = use_magnitude_channels

Expand Down Expand Up @@ -1424,8 +1433,8 @@ def encode(
xs += [x]
info = dict(xs=xs)

if exists(self.bottleneck):
x, info_bottleneck = self.bottleneck(x, with_info=True)
for bottleneck in self.bottlenecks:
x, info_bottleneck = bottleneck(x, with_info=True)
info = {**info, **info_bottleneck}

return (x, info) if with_info else x
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name="audio-diffusion-pytorch",
packages=find_packages(exclude=[]),
version="0.0.70",
version="0.0.71",
license="MIT",
description="Audio Diffusion - PyTorch",
long_description_content_type="text/markdown",
Expand Down

0 comments on commit cfe358f

Please sign in to comment.