Skip to content

Commit

Permalink
feat: option to norm inputs with mu-law
Browse files Browse the repository at this point in the history
  • Loading branch information
flavioschneider committed Oct 5, 2022
1 parent ed9d77c commit 37903b2
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 4 deletions.
25 changes: 22 additions & 3 deletions audio_diffusion_pytorch/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from einops_exts import rearrange_many
from torch import Tensor, einsum

from .utils import default, exists, prod
from .utils import default, exists, prod, wave_norm, wave_unnorm

"""
Utils
Expand Down Expand Up @@ -809,6 +809,7 @@ def __init__(
use_nearest_upsample: bool,
use_skip_scale: bool,
use_context_time: bool,
norm: float = 0.0,
out_channels: Optional[int] = None,
context_features: Optional[int] = None,
context_channels: Optional[Sequence[int]] = None,
Expand All @@ -822,6 +823,8 @@ def __init__(
use_context_channels = len(context_channels) > 0
context_mapping_features = None

self.norm = norm
self.use_norm = norm > 0.0
self.num_layers = num_layers
self.use_context_time = use_context_time
self.use_context_features = use_context_features
Expand Down Expand Up @@ -997,9 +1000,11 @@ def forward(
# Concat context channels at layer 0 if provided
channels = self.get_channels(channels_list, layer=0)
x = torch.cat([x, channels], dim=1) if exists(channels) else x

mapping = self.get_mapping(time, features)

if self.use_norm:
x = wave_norm(x, peak=self.norm)

x = self.to_in(x, mapping)
skips_list = [x]

Expand All @@ -1019,6 +1024,9 @@ def forward(
x += skips_list.pop()
x = self.to_out(x, mapping)

if self.use_norm:
x = wave_unnorm(x, peak=self.norm)

return x


Expand Down Expand Up @@ -1120,11 +1128,14 @@ def __init__(
num_blocks: Sequence[int],
use_noisy: bool = False,
bottleneck: Optional[Bottleneck] = None,
norm: float = 0.0,
):
super().__init__()
num_layers = len(multipliers) - 1
self.bottleneck = bottleneck
self.use_noisy = use_noisy
self.use_norm = norm > 0.0
self.norm = norm

assert len(factors) >= num_layers and len(num_blocks) >= num_layers

Expand Down Expand Up @@ -1174,6 +1185,9 @@ def __init__(
def encode(
self, x: Tensor, with_info: bool = False
) -> Union[Tensor, Tuple[Tensor, Any]]:
if self.use_norm:
x = wave_norm(x, peak=self.norm)

x = self.to_in(x)
for downsample in self.downsamples:
x = downsample(x)
Expand All @@ -1190,7 +1204,12 @@ def decode(self, x: Tensor) -> Tensor:
x = upsample(x)
if self.use_noisy:
x = torch.cat([x, torch.randn_like(x)], dim=1)
return self.to_out(x)
x = self.to_out(x)

if self.use_norm:
x = wave_unnorm(x, peak=self.norm)

return x


class MultiEncoder1d(nn.Module):
Expand Down
13 changes: 13 additions & 0 deletions audio_diffusion_pytorch/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,3 +83,16 @@ def downsample(waveforms: Tensor, factor: int, **kwargs) -> Tensor:

def upsample(waveforms: Tensor, factor: int, **kwargs) -> Tensor:
    """Resample ``waveforms`` upward by ``factor``.

    Thin wrapper that delegates to ``resample`` with ``factor_in=1`` and
    ``factor_out=factor``; any extra keyword arguments are forwarded as-is.
    """
    ratio = dict(factor_in=1, factor_out=factor)
    return resample(waveforms, **ratio, **kwargs)


def wave_norm(x: Tensor, bits: int = 24, peak: float = 0.5) -> Tensor:
    """Compand ``x`` with a mu-law curve and scale the result by ``peak``.

    Args:
        x: Input tensor; the transform is applied element-wise.
        bits: Exponent of the companding constant, ``mu = 2 ** bits``.
        peak: Factor applied after companding, so an input magnitude of 1
            maps to a magnitude of ``peak``.

    Returns:
        Tensor of the same shape as ``x`` holding
        ``peak * sign(x) * log1p(mu * |x|) / log1p(mu)``.
    """
    mu = 2 ** bits
    polarity = torch.sign(x)
    # Logarithmic compression acts on the magnitude only; the sign is
    # re-applied afterwards so the curve stays odd-symmetric.
    log_magnitude = torch.log1p(mu * torch.abs(x))
    companded = polarity * log_magnitude / math.log1p(mu)
    return companded * peak


def wave_unnorm(x: Tensor, bits: int = 24, peak: float = 0.5) -> Tensor:
    """Expand a mu-law companded tensor back to the linear domain.

    Applies the algebraic inverse of the curve used by ``wave_norm``:
    ``sign(y) * ((1 + mu) ** |y| - 1) / mu`` with ``y = x / peak``.

    Args:
        x: Companded tensor that was previously scaled by ``peak``.
        bits: Exponent of the companding constant, ``mu = 2 ** bits``; must
            match the value used when companding.
        peak: Scale that was applied after companding; divided out first.

    Returns:
        Tensor of the same shape as ``x`` with the expansion applied.
    """
    mu = 2 ** bits
    # Remove the peak scaling so the expansion sees the raw companded value.
    scaled = x / peak
    exponent = torch.abs(scaled) * math.log1p(mu)
    # expm1 evaluates exp(t) - 1 directly. Subtracting 1 from exp(t) cancels
    # most significant digits for small t, which in float32 distorts quiet
    # signal values noticeably.
    return torch.sign(scaled) * torch.expm1(exponent) / mu
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name="audio-diffusion-pytorch",
packages=find_packages(exclude=[]),
version="0.0.52",
version="0.0.53",
license="MIT",
description="Audio Diffusion - PyTorch",
long_description_content_type="text/markdown",
Expand Down

0 comments on commit 37903b2

Please sign in to comment.