Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add causal option for HiFiGAN #326

Merged
merged 6 commits into from
Feb 19, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 14 additions & 2 deletions parallel_wavegan/layers/causal_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,21 @@ def forward(self, x):
class CausalConvTranspose1d(torch.nn.Module):
"""CausalConvTranspose1d module with customized initialization."""

def __init__(self, in_channels, out_channels, kernel_size, stride, bias=True):
def __init__(
self,
in_channels,
out_channels,
kernel_size,
stride,
bias=True,
pad="ReplicationPad1d",
pad_params={},
):
"""Initialize CausalConvTranspose1d module."""
super(CausalConvTranspose1d, self).__init__()
# NOTE (yoneyama): This padding is to match the number of inputs
# used to calculate the first output sample with the others.
self.pad = getattr(torch.nn, pad)((1, 0), **pad_params)
self.deconv = torch.nn.ConvTranspose1d(
in_channels, out_channels, kernel_size, stride, bias=bias
)
Expand All @@ -63,4 +75,4 @@ def forward(self, x):
Tensor: Output tensor (B, out_channels, T_out).

"""
return self.deconv(x)[:, :, : -self.stride]
return self.deconv(self.pad(x))[:, :, self.stride : -self.stride]
63 changes: 50 additions & 13 deletions parallel_wavegan/layers/residual_block.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
import torch
import torch.nn.functional as F

from parallel_wavegan.layers.causal_conv import CausalConv1d


class Conv1d(torch.nn.Conv1d):
"""Conv1d module with customized initialization."""
Expand Down Expand Up @@ -150,6 +152,9 @@ def __init__(
use_additional_convs=True,
nonlinear_activation="LeakyReLU",
nonlinear_activation_params={"negative_slope": 0.1},
pad="ReplicationPad1d",
pad_params={},
use_causal_conv=False,
):
"""Initialize HiFiGANResidualBlock module.

Expand All @@ -161,47 +166,79 @@ def __init__(
bias (bool): Whether to add bias parameter in convolution layers.
nonlinear_activation (str): Activation function module name.
nonlinear_activation_params (dict): Hyperparameters for activation function.
pad (str): Padding function module name before convolution layer.
pad_params (dict): Hyperparameters for padding function.
use_causal_conv (bool): Whether to use causal structure.

"""
super().__init__()
self.use_additional_convs = use_additional_convs
self.convs1 = torch.nn.ModuleList()
if use_additional_convs:
self.convs2 = torch.nn.ModuleList()
self.use_causal_conv = use_causal_conv
assert kernel_size % 2 == 1, "Kernel size must be odd number."
for dilation in dilations:
self.convs1 += [
torch.nn.Sequential(
getattr(torch.nn, nonlinear_activation)(
**nonlinear_activation_params
if not use_causal_conv:
conv = torch.nn.Sequential(
getattr(torch.nn, pad)(
(kernel_size - 1) // 2 * dilation, **pad_params
),
torch.nn.Conv1d(
channels,
channels,
kernel_size,
1,
dilation=dilation,
bias=bias,
padding=(kernel_size - 1) // 2 * dilation,
),
)
else:
conv = CausalConv1d(
channels,
channels,
kernel_size,
dilation=dilation,
bias=bias,
pad=pad,
pad_params=pad_params,
)
self.convs1 += [
torch.nn.Sequential(
getattr(torch.nn, nonlinear_activation)(
**nonlinear_activation_params
),
conv,
)
]
if use_additional_convs:
self.convs2 += [
torch.nn.Sequential(
getattr(torch.nn, nonlinear_activation)(
**nonlinear_activation_params
),
if not use_causal_conv:
conv = torch.nn.Sequential(
getattr(torch.nn, pad)((kernel_size - 1) // 2, **pad_params),
torch.nn.Conv1d(
channels,
channels,
kernel_size,
1,
dilation=1,
bias=bias,
padding=(kernel_size - 1) // 2,
),
)
else:
conv = CausalConv1d(
channels,
channels,
kernel_size,
dilation=1,
bias=bias,
pad=pad,
pad_params=pad_params,
)
self.convs2 += [
torch.nn.Sequential(
getattr(torch.nn, nonlinear_activation)(
**nonlinear_activation_params
),
conv,
)
]

def forward(self, x):
Expand Down
94 changes: 72 additions & 22 deletions parallel_wavegan/models/hifigan.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
import torch
import torch.nn.functional as F

from parallel_wavegan.layers import CausalConv1d
from parallel_wavegan.layers import CausalConvTranspose1d
from parallel_wavegan.layers import HiFiGANResidualBlock as ResidualBlock
from parallel_wavegan.utils import read_hdf5

Expand All @@ -34,6 +36,9 @@ def __init__(
bias=True,
nonlinear_activation="LeakyReLU",
nonlinear_activation_params={"negative_slope": 0.1},
pad="ReplicationPad1d",
pad_params={},
use_causal_conv=False,
use_weight_norm=True,
):
"""Initialize HiFiGANGenerator module.
Expand All @@ -51,6 +56,9 @@ def __init__(
bias (bool): Whether to add bias parameter in convolution layers.
nonlinear_activation (str): Activation function module name.
nonlinear_activation_params (dict): Hyperparameters for activation function.
pad (str): Padding function module name before convolution layer.
pad_params (dict): Hyperparameters for padding function.
use_causal_conv (bool): Whether to use causal structure.
use_weight_norm (bool): Whether to use weight norm.
If set to true, it will be applied to all of the conv layers.

Expand All @@ -65,30 +73,56 @@ def __init__(
# define modules
self.num_upsamples = len(upsample_kernel_sizes)
self.num_blocks = len(resblock_kernel_sizes)
self.input_conv = torch.nn.Conv1d(
in_channels,
channels,
kernel_size,
1,
padding=(kernel_size - 1) // 2,
)
self.use_causal_conv = use_causal_conv
if not use_causal_conv:
self.input_conv = torch.nn.Sequential(
getattr(torch.nn, pad)((kernel_size - 1) // 2, **pad_params),
torch.nn.Conv1d(
in_channels,
channels,
kernel_size,
bias=bias,
),
)
else:
self.input_conv = CausalConv1d(
in_channels,
channels,
kernel_size,
bias=bias,
pad=pad,
pad_params=pad_params,
)
self.upsamples = torch.nn.ModuleList()
self.blocks = torch.nn.ModuleList()
for i in range(len(upsample_kernel_sizes)):
assert upsample_kernel_sizes[i] == 2 * upsample_scales[i]
if not use_causal_conv:
conv = torch.nn.ConvTranspose1d(
channels // (2 ** i),
channels // (2 ** (i + 1)),
upsample_kernel_sizes[i],
upsample_scales[i],
padding=upsample_scales[i] // 2 + upsample_scales[i] % 2,
output_padding=upsample_scales[i] % 2,
bias=bias,
)
else:
conv = CausalConvTranspose1d(
channels // (2 ** i),
channels // (2 ** (i + 1)),
upsample_kernel_sizes[i],
upsample_scales[i],
bias=bias,
pad=pad,
pad_params=pad_params,
)
self.upsamples += [
torch.nn.Sequential(
getattr(torch.nn, nonlinear_activation)(
**nonlinear_activation_params
),
torch.nn.ConvTranspose1d(
channels // (2 ** i),
channels // (2 ** (i + 1)),
upsample_kernel_sizes[i],
upsample_scales[i],
padding=upsample_scales[i] // 2 + upsample_scales[i] % 2,
output_padding=upsample_scales[i] % 2,
),
conv,
)
]
for j in range(len(resblock_kernel_sizes)):
Expand All @@ -101,19 +135,35 @@ def __init__(
use_additional_convs=use_additional_convs,
nonlinear_activation=nonlinear_activation,
nonlinear_activation_params=nonlinear_activation_params,
pad=pad,
pad_params=pad_params,
use_causal_conv=use_causal_conv,
)
]
if not use_causal_conv:
conv = torch.nn.Sequential(
getattr(torch.nn, pad)((kernel_size - 1) // 2, **pad_params),
torch.nn.Conv1d(
channels // (2 ** (i + 1)),
out_channels,
kernel_size,
bias=bias,
),
)
else:
conv = CausalConv1d(
channels // (2 ** (i + 1)),
out_channels,
kernel_size,
bias=bias,
pad=pad,
pad_params=pad_params,
)
self.output_conv = torch.nn.Sequential(
# NOTE(kan-bayashi): follow official implementation but why
# using different slope parameter here? (0.1 vs. 0.01)
torch.nn.LeakyReLU(),
torch.nn.Conv1d(
channels // (2 ** (i + 1)),
out_channels,
kernel_size,
1,
padding=(kernel_size - 1) // 2,
),
conv,
torch.nn.Tanh(),
)

Expand Down
65 changes: 65 additions & 0 deletions test/test_hifigan.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,10 @@ def make_hifigan_generator_args(**kwargs):
bias=True,
nonlinear_activation="LeakyReLU",
nonlinear_activation_params={"negative_slope": 0.1},
pad="ReplicationPad1d",
pad_params={},
use_weight_norm=True,
use_causal_conv=False,
)
defaults.update(kwargs)
return defaults
Expand Down Expand Up @@ -153,3 +156,65 @@ def test_hifigan_trainable(dict_g, dict_d, dict_loss):

print(model_d)
print(model_g)


@pytest.mark.parametrize(
"dict_g",
[
(
{
"use_causal_conv": True,
"upsample_scales": [5, 5, 4, 3],
"upsample_kernel_sizes": [10, 10, 8, 6],
}
),
(
{
"use_causal_conv": True,
"upsample_scales": [8, 8, 2, 2],
"upsample_kernel_sizes": [16, 16, 4, 4],
}
),
(
{
"use_causal_conv": True,
"upsample_scales": [4, 5, 4, 3],
"upsample_kernel_sizes": [8, 10, 8, 6],
}
),
(
{
"use_causal_conv": True,
"upsample_scales": [4, 4, 2, 2],
"upsample_kernel_sizes": [8, 8, 4, 4],
}
),
],
)
def test_causal_hifigan(dict_g):
batch_size = 4
batch_length = 8192
args_g = make_hifigan_generator_args(**dict_g)
upsampling_factor = np.prod(args_g["upsample_scales"])
c = torch.randn(
batch_size, args_g["in_channels"], batch_length // upsampling_factor
)
model_g = HiFiGANGenerator(**args_g)
c_ = c.clone()
c_[..., c.size(-1) // 2 :] = torch.randn(c[..., c.size(-1) // 2 :].shape)
try:
# check not equal
np.testing.assert_array_equal(c.numpy(), c_.numpy())
except AssertionError:
pass
else:
raise AssertionError("Must be different.")

# check causality
y = model_g(c)
y_ = model_g(c_)
assert y.size(2) == c.size(2) * upsampling_factor
np.testing.assert_array_equal(
y[..., : c.size(-1) // 2 * upsampling_factor].detach().cpu().numpy(),
y_[..., : c_.size(-1) // 2 * upsampling_factor].detach().cpu().numpy(),
)