Skip to content

Commit

Permalink
feat: add xunet, refactor all models with xunet
Browse files Browse the repository at this point in the history
  • Loading branch information
flavioschneider committed Nov 25, 2022
1 parent 5bf1837 commit d730653
Show file tree
Hide file tree
Showing 5 changed files with 31 additions and 23 deletions.
3 changes: 1 addition & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -150,8 +150,7 @@ from audio_diffusion_pytorch import UNet1d
unet = UNet1d(
in_channels=1,
channels=128,
patch_factor=16,
patch_blocks=1,
patch_size=16,
multipliers=[1, 2, 4, 4, 4, 4, 4],
factors=[4, 4, 4, 2, 2, 2],
attentions=[0, 0, 0, 1, 1, 1, 1],
Expand Down
2 changes: 1 addition & 1 deletion audio_diffusion_pytorch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,4 +35,4 @@
DiffusionVocoder1d,
Model1d,
)
from .modules import NumberEmbedder, T5Embedder, UNet1d
from .modules import NumberEmbedder, T5Embedder, UNet1d, XUNet1d
34 changes: 15 additions & 19 deletions audio_diffusion_pytorch/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from tqdm import tqdm

from .diffusion import LinearSchedule, UniformDistribution, VSampler, XDiffusion
from .modules import STFT, SinusoidalEmbedding, UNet1d, UNetCFG1d, rand_bool
from .modules import STFT, SinusoidalEmbedding, XUNet1d, rand_bool
from .utils import (
closest_power_2,
default,
Expand All @@ -28,18 +28,11 @@


class Model1d(nn.Module):
def __init__(
self, diffusion_type: str, use_classifier_free_guidance: bool = False, **kwargs
):
def __init__(self, unet_type: str = "base", **kwargs):
super().__init__()
diffusion_kwargs, kwargs = groupby("diffusion_", kwargs)

UNet = UNetCFG1d if use_classifier_free_guidance else UNet1d
self.unet = UNet(**kwargs)

self.diffusion = XDiffusion(
type=diffusion_type, net=self.unet, **diffusion_kwargs
)
self.unet = XUNet1d(type=unet_type, **kwargs)
self.diffusion = XDiffusion(net=self.unet, **diffusion_kwargs)

def forward(self, x: Tensor, **kwargs) -> Tensor:
return self.diffusion(x, **kwargs)
Expand Down Expand Up @@ -119,10 +112,10 @@ def __init__(
encoder_channels: int,
encoder_factors: Sequence[int],
encoder_multipliers: Sequence[int],
diffusion_type: str,
encoder_patch_size: int = 1,
bottleneck: Union[Bottleneck, Sequence[Bottleneck]] = [],
bottleneck_channels: Optional[int] = None,
unet_type: str = "base",
**kwargs,
):
super().__init__()
Expand All @@ -138,13 +131,14 @@ def __init__(
else:
context_channels += [encoder_channels * encoder_multipliers[-1]]

self.unet = UNet1d(
in_channels=in_channels, context_channels=context_channels, **kwargs
self.unet = XUNet1d(
type=unet_type,
in_channels=in_channels,
context_channels=context_channels,
**kwargs,
)

self.diffusion = XDiffusion(
type=diffusion_type, net=self.unet, **diffusion_kwargs
)
self.diffusion = XDiffusion(net=self.unet, **diffusion_kwargs)

self.encoder = Encoder1d(
in_channels=in_channels,
Expand Down Expand Up @@ -207,6 +201,7 @@ def __init__(
encoder_patch_size: int = 1,
bottleneck: Union[Bottleneck, Sequence[Bottleneck]] = [],
bottleneck_channels: Optional[int] = None,
unet_type: str = "base",
**kwargs,
):
super().__init__()
Expand All @@ -233,7 +228,8 @@ def __init__(
use_complex=False, # Magnitude encoding
)

self.unet = UNet1d(
self.unet = XUNet1d(
type=unet_type,
in_channels=in_channels,
context_channels=context_channels,
use_stft=True,
Expand Down Expand Up @@ -546,9 +542,9 @@ def __init__(
self.embedding_mask_proba = embedding_mask_proba
default_kwargs = dict(
**get_default_model_kwargs(),
unet_type="cfg",
context_embedding_features=embedding_features,
context_embedding_max_length=embedding_max_length,
use_classifier_free_guidance=True,
)
super().__init__(**{**default_kwargs, **kwargs})

Expand Down
13 changes: 13 additions & 0 deletions audio_diffusion_pytorch/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -1247,6 +1247,19 @@ def forward(self, *args, **kwargs): # type: ignore
return UNetCFG1d.forward(self, *args, **kwargs)


def XUNet1d(type: str = "base", **kwargs) -> UNet1d:
if type == "base":
return UNet1d(**kwargs)
elif type == "all":
return UNetAll1d(**kwargs)
elif type == "cfg":
return UNetCFG1d(**kwargs)
elif type == "ncca":
return UNetNCCA1d(**kwargs)
else:
raise ValueError(f"Unknown XUNet1d type: {type}")


class T5Embedder(nn.Module):
def __init__(self, model: str = "t5-base", max_length: int = 64):
super().__init__()
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name="audio-diffusion-pytorch",
packages=find_packages(exclude=[]),
version="0.0.94",
version="0.0.95",
license="MIT",
description="Audio Diffusion - PyTorch",
long_description_content_type="text/markdown",
Expand Down

0 comments on commit d730653

Please sign in to comment.