From eafdfea2a08b0f1c217d733344d572b63a97d4ae Mon Sep 17 00:00:00 2001 From: chomeyama Date: Wed, 9 Feb 2022 11:22:30 +0900 Subject: [PATCH 1/5] add causal option for HiFiGAN --- parallel_wavegan/layers/residual_block.py | 16 ++++++- parallel_wavegan/models/hifigan.py | 55 ++++++++++++++++------- 2 files changed, 54 insertions(+), 17 deletions(-) diff --git a/parallel_wavegan/layers/residual_block.py b/parallel_wavegan/layers/residual_block.py index e0e9d6d2..4568f0de 100644 --- a/parallel_wavegan/layers/residual_block.py +++ b/parallel_wavegan/layers/residual_block.py @@ -150,6 +150,7 @@ def __init__( use_additional_convs=True, nonlinear_activation="LeakyReLU", nonlinear_activation_params={"negative_slope": 0.1}, + use_causal_conv=False, ): """Initialize HiFiGANResidualBlock module. @@ -161,6 +162,7 @@ def __init__( bias (bool): Whether to add bias parameter in convolution layers. nonlinear_activation (str): Activation function module name. nonlinear_activation_params (dict): Hyperparameters for activation function. + use_causal_conv (bool): Whether to use causal structure. """ super().__init__() @@ -168,8 +170,13 @@ def __init__( self.convs1 = torch.nn.ModuleList() if use_additional_convs: self.convs2 = torch.nn.ModuleList() + self.use_causal_conv = use_causal_conv assert kernel_size % 2 == 1, "Kernel size must be odd number." for dilation in dilations: + if use_causal_conv: + padding = (kernel_size - 1) * dilation + else: + padding = (kernel_size - 1) // 2 * dilation self.convs1 += [ torch.nn.Sequential( getattr(torch.nn, nonlinear_activation)( @@ -182,11 +189,12 @@ def __init__( 1, dilation=dilation, bias=bias, - padding=(kernel_size - 1) // 2 * dilation, + padding=padding, ), ) ] if use_additional_convs: + padding = kernel_size - 1 if use_causal_conv else (kernel_size - 1) // 2 self.convs2 += [ torch.nn.Sequential( getattr(torch.nn, nonlinear_activation)( @@ -199,7 +207,7 @@ def __init__( 1, dilation=1, bias=bias, - padding=(kernel_size - 1) // 2, + padding=padding, ), ) ] @@ -216,7 +224,11 @@ def forward(self, x): """ for idx in range(len(self.convs1)): xt = self.convs1[idx](x) + if self.use_causal_conv: + xt = xt[:, :, : x.size(-1)] if self.use_additional_convs: xt = self.convs2[idx](xt) + if self.use_causal_conv: + xt = xt[:, :, : x.size(-1)] x = xt + x return x diff --git a/parallel_wavegan/models/hifigan.py b/parallel_wavegan/models/hifigan.py index b0b0287e..95f8d3da 100644 --- a/parallel_wavegan/models/hifigan.py +++ b/parallel_wavegan/models/hifigan.py @@ -34,6 +34,7 @@ def __init__( bias=True, nonlinear_activation="LeakyReLU", nonlinear_activation_params={"negative_slope": 0.1}, + use_causal_conv=False, use_weight_norm=True, ): """Initialize HiFiGANGenerator module. @@ -51,6 +52,7 @@ def __init__( bias (bool): Whether to add bias parameter in convolution layers. nonlinear_activation (str): Activation function module name. nonlinear_activation_params (dict): Hyperparameters for activation function. + use_causal_conv (bool): Whether to use causal structure. use_weight_norm (bool): Whether to use weight norm. If set to true, it will be applied to all of the conv layers. 
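The residual-block change above makes each dilated Conv1d causal by padding with (kernel_size - 1) * dilation and then trimming the output back to the input length in forward(). A minimal sketch of that arithmetic, using plain torch modules and illustrative sizes rather than the patched HiFiGANResidualBlock:

import torch

torch.manual_seed(0)
kernel_size, dilation, length = 3, 3, 32
conv = torch.nn.Conv1d(
    1,
    1,
    kernel_size,
    dilation=dilation,
    padding=(kernel_size - 1) * dilation,  # same padding rule as the patch
)

x1 = torch.randn(1, 1, length)
x2 = x1.clone()
x2[..., length // 2 :] = torch.randn(1, 1, length - length // 2)  # change only the "future" half

y1 = conv(x1)[..., :length]  # trim as in xt[:, :, : x.size(-1)]
y2 = conv(x2)[..., :length]

# A causal conv must not let the changed future leak into the past.
assert torch.equal(y1[..., : length // 2], y2[..., : length // 2])

Padding only on the left gives the same causal receptive field without the post-trim; the later patches in this series switch to that form via CausalConv1d.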
@@ -72,25 +74,45 @@ def __init__( 1, padding=(kernel_size - 1) // 2, ) + self.use_causal_conv = use_causal_conv + self.upsample_scales = upsample_scales self.upsamples = torch.nn.ModuleList() self.blocks = torch.nn.ModuleList() for i in range(len(upsample_kernel_sizes)): assert upsample_kernel_sizes[i] == 2 * upsample_scales[i] - self.upsamples += [ - torch.nn.Sequential( - getattr(torch.nn, nonlinear_activation)( - **nonlinear_activation_params - ), - torch.nn.ConvTranspose1d( - channels // (2 ** i), - channels // (2 ** (i + 1)), - upsample_kernel_sizes[i], - upsample_scales[i], - padding=upsample_scales[i] // 2 + upsample_scales[i] % 2, - output_padding=upsample_scales[i] % 2, - ), - ) - ] + if use_causal_conv: + self.upsamples += [ + torch.nn.Sequential( + getattr(torch.nn, nonlinear_activation)( + **nonlinear_activation_params + ), + torch.nn.ReplicationPad1d((1, 0)), + torch.nn.ConvTranspose1d( + channels // (2 ** i), + channels // (2 ** (i + 1)), + upsample_kernel_sizes[i], + upsample_scales[i], + padding=0, + output_padding=0, + ), + ) + ] + else: + self.upsamples += [ + torch.nn.Sequential( + getattr(torch.nn, nonlinear_activation)( + **nonlinear_activation_params + ), + torch.nn.ConvTranspose1d( + channels // (2 ** i), + channels // (2 ** (i + 1)), + upsample_kernel_sizes[i], + upsample_scales[i], + padding=upsample_scales[i] // 2 + upsample_scales[i] % 2, + output_padding=upsample_scales[i] % 2, + ), + ) + ] for j in range(len(resblock_kernel_sizes)): self.blocks += [ ResidualBlock( @@ -101,6 +123,7 @@ def __init__( use_additional_convs=use_additional_convs, nonlinear_activation=nonlinear_activation, nonlinear_activation_params=nonlinear_activation_params, + use_causal_conv=use_causal_conv, ) ] self.output_conv = torch.nn.Sequential( @@ -137,6 +160,8 @@ def forward(self, c): c = self.input_conv(c) for i in range(self.num_upsamples): c = self.upsamples[i](c) + if self.use_causal_conv: + c = c[:, :, self.upsample_scales[i] : -self.upsample_scales[i]] cs = 0.0 # initialize for j in range(self.num_blocks): cs += self.blocks[i * self.num_blocks + j](c) From 6b0cd4efda27b91bb9f72ebdb2f589adebb0526c Mon Sep 17 00:00:00 2001 From: chomeyama Date: Mon, 14 Feb 2022 11:55:19 +0900 Subject: [PATCH 2/5] add test code for causal hifigan --- test/test_hifigan.py | 65 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 65 insertions(+) diff --git a/test/test_hifigan.py b/test/test_hifigan.py index f65f5ae5..117cf284 100644 --- a/test/test_hifigan.py +++ b/test/test_hifigan.py @@ -40,7 +40,10 @@ def make_hifigan_generator_args(**kwargs): bias=True, nonlinear_activation="LeakyReLU", nonlinear_activation_params={"negative_slope": 0.1}, + pad="ReplicationPad1d", + pad_params={}, use_weight_norm=True, + use_causal_conv=False, ) defaults.update(kwargs) return defaults @@ -153,3 +156,65 @@ def test_hifigan_trainable(dict_g, dict_d, dict_loss): print(model_d) print(model_g) + + +@pytest.mark.parametrize( + "dict_g", + [ + ( + { + "use_causal_conv": True, + "upsample_scales": [5, 5, 4, 3], + "upsample_kernel_sizes": [10, 10, 8, 6], + } + ), + ( + { + "use_causal_conv": True, + "upsample_scales": [8, 8, 2, 2], + "upsample_kernel_sizes": [16, 16, 4, 4], + } + ), + ( + { + "use_causal_conv": True, + "upsample_scales": [4, 5, 4, 3], + "upsample_kernel_sizes": [8, 10, 8, 6], + } + ), + ( + { + "use_causal_conv": True, + "upsample_scales": [4, 4, 2, 2], + "upsample_kernel_sizes": [8, 8, 4, 4], + } + ), + ], +) +def test_causal_hifigan(dict_g): + batch_size = 4 + batch_length = 8192 + args_g = 
make_hifigan_generator_args(**dict_g) + upsampling_factor = np.prod(args_g["upsample_scales"]) + c = torch.randn( + batch_size, args_g["in_channels"], batch_length // upsampling_factor + ) + model_g = HiFiGANGenerator(**args_g) + c_ = c.clone() + c_[..., c.size(-1) // 2 :] = torch.randn(c[..., c.size(-1) // 2 :].shape) + try: + # check not equal + np.testing.assert_array_equal(c.numpy(), c_.numpy()) + except AssertionError: + pass + else: + raise AssertionError("Must be different.") + + # check causality + y = model_g(c) + y_ = model_g(c_) + assert y.size(2) == c.size(2) * upsampling_factor + np.testing.assert_array_equal( + y[..., : c.size(-1) // 2 * upsampling_factor].detach().cpu().numpy(), + y_[..., : c_.size(-1) // 2 * upsampling_factor].detach().cpu().numpy(), + ) From 25c4b9a02b21ef1d464e61101ec6ce41014dbaa2 Mon Sep 17 00:00:00 2001 From: chomeyama Date: Mon, 14 Feb 2022 11:57:07 +0900 Subject: [PATCH 3/5] add padding on the left side for causal transposed conv1d --- parallel_wavegan/layers/causal_conv.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/parallel_wavegan/layers/causal_conv.py b/parallel_wavegan/layers/causal_conv.py index abf51b8e..74194930 100644 --- a/parallel_wavegan/layers/causal_conv.py +++ b/parallel_wavegan/layers/causal_conv.py @@ -45,9 +45,21 @@ def forward(self, x): class CausalConvTranspose1d(torch.nn.Module): """CausalConvTranspose1d module with customized initialization.""" - def __init__(self, in_channels, out_channels, kernel_size, stride, bias=True): + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride, + bias=True, + pad="ReplicationPad1d", + pad_params={}, + ): """Initialize CausalConvTranspose1d module.""" super(CausalConvTranspose1d, self).__init__() + # NOTE (yoneyama): This padding is to match the number of inputs + # used to calculate the first output sample with the others. + self.pad = getattr(torch.nn, pad)((1, 0), **pad_params) self.deconv = torch.nn.ConvTranspose1d( in_channels, out_channels, kernel_size, stride, bias=bias ) @@ -63,4 +75,4 @@ def forward(self, x): Tensor: Output tensor (B, out_channels, T_out). """ - return self.deconv(x)[:, :, : -self.stride] + return self.deconv(self.pad(x))[:, :, self.stride : -self.stride] From 3886c44a32197dd19393fac0c590ed7eac3b42a9 Mon Sep 17 00:00:00 2001 From: chomeyama Date: Mon, 14 Feb 2022 12:00:14 +0900 Subject: [PATCH 4/5] add padding option and modified causal option for hifigan --- parallel_wavegan/layers/residual_block.py | 69 ++++++++---- parallel_wavegan/models/hifigan.py | 123 +++++++++++++--------- 2 files changed, 121 insertions(+), 71 deletions(-) diff --git a/parallel_wavegan/layers/residual_block.py b/parallel_wavegan/layers/residual_block.py index 4568f0de..3552e293 100644 --- a/parallel_wavegan/layers/residual_block.py +++ b/parallel_wavegan/layers/residual_block.py @@ -13,6 +13,8 @@ import torch import torch.nn.functional as F +from parallel_wavegan.layers.causal_conv import CausalConv1d + class Conv1d(torch.nn.Conv1d): """Conv1d module with customized initialization.""" @@ -150,6 +152,8 @@ def __init__( use_additional_convs=True, nonlinear_activation="LeakyReLU", nonlinear_activation_params={"negative_slope": 0.1}, + pad="ReplicationPad1d", + pad_params={}, use_causal_conv=False, ): """Initialize HiFiGANResidualBlock module. @@ -162,6 +166,8 @@ def __init__( bias (bool): Whether to add bias parameter in convolution layers. nonlinear_activation (str): Activation function module name. 
nonlinear_activation_params (dict): Hyperparameters for activation function. + pad (str): Padding function module name before convolution layer. + pad_params (dict): Hyperparameters for padding function. use_causal_conv (bool): Whether to use causal structure. """ @@ -173,43 +179,66 @@ def __init__( self.use_causal_conv = use_causal_conv assert kernel_size % 2 == 1, "Kernel size must be odd number." for dilation in dilations: - if use_causal_conv: - padding = (kernel_size - 1) * dilation - else: - padding = (kernel_size - 1) // 2 * dilation - self.convs1 += [ - torch.nn.Sequential( - getattr(torch.nn, nonlinear_activation)( - **nonlinear_activation_params + if not use_causal_conv: + conv = torch.nn.Sequential( + getattr(torch.nn, pad)( + (kernel_size - 1) // 2 * dilation, **pad_params ), torch.nn.Conv1d( channels, channels, kernel_size, - 1, dilation=dilation, bias=bias, - padding=padding, ), ) + else: + conv = CausalConv1d( + channels, + channels, + kernel_size, + dilation=dilation, + bias=bias, + pad=pad, + pad_params=pad_params, + ) + self.convs1 += [ + torch.nn.Sequential( + getattr(torch.nn, nonlinear_activation)( + **nonlinear_activation_params + ), + conv, + ) ] if use_additional_convs: - padding = kernel_size - 1 if use_causal_conv else (kernel_size - 1) // 2 - self.convs2 += [ - torch.nn.Sequential( - getattr(torch.nn, nonlinear_activation)( - **nonlinear_activation_params - ), + if not use_causal_conv: + conv = torch.nn.Sequential( + getattr(torch.nn, pad)((kernel_size - 1) // 2, **pad_params), torch.nn.Conv1d( channels, channels, kernel_size, - 1, dilation=1, bias=bias, - padding=padding, ), ) + else: + conv = CausalConv1d( + channels, + channels, + kernel_size, + dilation=1, + bias=bias, + pad=pad, + pad_params=pad_params, + ) + self.convs2 += [ + torch.nn.Sequential( + getattr(torch.nn, nonlinear_activation)( + **nonlinear_activation_params + ), + conv, + ) ] def forward(self, x): @@ -224,11 +253,7 @@ def forward(self, x): """ for idx in range(len(self.convs1)): xt = self.convs1[idx](x) - if self.use_causal_conv: - xt = xt[:, :, : x.size(-1)] if self.use_additional_convs: xt = self.convs2[idx](xt) - if self.use_causal_conv: - xt = xt[:, :, : x.size(-1)] x = xt + x return x diff --git a/parallel_wavegan/models/hifigan.py b/parallel_wavegan/models/hifigan.py index 95f8d3da..f039b9ca 100644 --- a/parallel_wavegan/models/hifigan.py +++ b/parallel_wavegan/models/hifigan.py @@ -14,6 +14,8 @@ import torch.nn.functional as F from parallel_wavegan.layers import HiFiGANResidualBlock as ResidualBlock +from parallel_wavegan.layers import CausalConv1d +from parallel_wavegan.layers import CausalConvTranspose1d from parallel_wavegan.utils import read_hdf5 @@ -34,6 +36,8 @@ def __init__( bias=True, nonlinear_activation="LeakyReLU", nonlinear_activation_params={"negative_slope": 0.1}, + pad="ReplicationPad1d", + pad_params={}, use_causal_conv=False, use_weight_norm=True, ): @@ -52,6 +56,8 @@ def __init__( bias (bool): Whether to add bias parameter in convolution layers. nonlinear_activation (str): Activation function module name. nonlinear_activation_params (dict): Hyperparameters for activation function. + pad (str): Padding function module name before convolution layer. + pad_params (dict): Hyperparameters for padding function. use_causal_conv (bool): Whether to use causal structure. use_weight_norm (bool): Whether to use weight norm. If set to true, it will be applied to all of the conv layers. 
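PATCH 4 reworks the residual block to use the package's CausalConv1d plus a configurable pad module instead of the manual both-sides padding from PATCH 1. The underlying idea is to pad only on the left by (kernel_size - 1) * dilation and then run an unpadded Conv1d, so the output already matches the input length. A rough sketch of such a wrapper, with an illustrative class name and defaults, not the repo's exact implementation:

import torch


class LeftPaddedConv1d(torch.nn.Module):
    """Illustrative causal Conv1d: pad on the left only, then convolve without padding."""

    def __init__(self, channels, kernel_size, dilation=1, bias=True, pad="ReplicationPad1d"):
        super().__init__()
        # Left padding equal to the receptive field minus one sample keeps the layer causal.
        self.pad = getattr(torch.nn, pad)(((kernel_size - 1) * dilation, 0))
        self.conv = torch.nn.Conv1d(
            channels, channels, kernel_size, dilation=dilation, bias=bias
        )

    def forward(self, x):
        # Output length equals input length, so no post-trim is needed.
        return self.conv(self.pad(x))


x = torch.randn(2, 4, 100)
assert LeftPaddedConv1d(4, kernel_size=3, dilation=9)(x).shape == x.shape

This is also why PATCH 4 can drop the output slicing that PATCH 1 added to HiFiGANResidualBlock.forward().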
@@ -67,52 +73,58 @@ def __init__( # define modules self.num_upsamples = len(upsample_kernel_sizes) self.num_blocks = len(resblock_kernel_sizes) - self.input_conv = torch.nn.Conv1d( - in_channels, - channels, - kernel_size, - 1, - padding=(kernel_size - 1) // 2, - ) self.use_causal_conv = use_causal_conv - self.upsample_scales = upsample_scales + if not use_causal_conv: + self.input_conv = torch.nn.Sequential( + getattr(torch.nn, pad)((kernel_size - 1) // 2, **pad_params), + torch.nn.Conv1d( + in_channels, + channels, + kernel_size, + bias=bias, + ), + ) + else: + self.input_conv = CausalConv1d( + in_channels, + channels, + kernel_size, + bias=bias, + pad=pad, + pad_params=pad_params, + ) self.upsamples = torch.nn.ModuleList() self.blocks = torch.nn.ModuleList() for i in range(len(upsample_kernel_sizes)): assert upsample_kernel_sizes[i] == 2 * upsample_scales[i] - if use_causal_conv: - self.upsamples += [ - torch.nn.Sequential( - getattr(torch.nn, nonlinear_activation)( - **nonlinear_activation_params - ), - torch.nn.ReplicationPad1d((1, 0)), - torch.nn.ConvTranspose1d( - channels // (2 ** i), - channels // (2 ** (i + 1)), - upsample_kernel_sizes[i], - upsample_scales[i], - padding=0, - output_padding=0, - ), - ) - ] + if not use_causal_conv: + conv = torch.nn.ConvTranspose1d( + channels // (2 ** i), + channels // (2 ** (i + 1)), + upsample_kernel_sizes[i], + upsample_scales[i], + padding=upsample_scales[i] // 2 + upsample_scales[i] % 2, + output_padding=upsample_scales[i] % 2, + bias=bias, + ) else: - self.upsamples += [ - torch.nn.Sequential( - getattr(torch.nn, nonlinear_activation)( - **nonlinear_activation_params - ), - torch.nn.ConvTranspose1d( - channels // (2 ** i), - channels // (2 ** (i + 1)), - upsample_kernel_sizes[i], - upsample_scales[i], - padding=upsample_scales[i] // 2 + upsample_scales[i] % 2, - output_padding=upsample_scales[i] % 2, - ), - ) - ] + conv = CausalConvTranspose1d( + channels // (2 ** i), + channels // (2 ** (i + 1)), + upsample_kernel_sizes[i], + upsample_scales[i], + bias=bias, + pad=pad, + pad_params=pad_params, + ) + self.upsamples += [ + torch.nn.Sequential( + getattr(torch.nn, nonlinear_activation)( + **nonlinear_activation_params + ), + conv, + ) + ] for j in range(len(resblock_kernel_sizes)): self.blocks += [ ResidualBlock( @@ -123,20 +135,35 @@ def __init__( use_additional_convs=use_additional_convs, nonlinear_activation=nonlinear_activation, nonlinear_activation_params=nonlinear_activation_params, + pad=pad, + pad_params=pad_params, use_causal_conv=use_causal_conv, ) ] + if not use_causal_conv: + conv = torch.nn.Sequential( + getattr(torch.nn, pad)((kernel_size - 1) // 2, **pad_params), + torch.nn.Conv1d( + channels // (2 ** (i + 1)), + out_channels, + kernel_size, + bias=bias, + ), + ) + else: + conv = CausalConv1d( + channels // (2 ** (i + 1)), + out_channels, + kernel_size, + bias=bias, + pad=pad, + pad_params=pad_params, + ) self.output_conv = torch.nn.Sequential( # NOTE(kan-bayashi): follow official implementation but why # using different slope parameter here? (0.1 vs. 
0.01) torch.nn.LeakyReLU(), - torch.nn.Conv1d( - channels // (2 ** (i + 1)), - out_channels, - kernel_size, - 1, - padding=(kernel_size - 1) // 2, - ), + conv, torch.nn.Tanh(), ) @@ -160,8 +187,6 @@ def forward(self, c): c = self.input_conv(c) for i in range(self.num_upsamples): c = self.upsamples[i](c) - if self.use_causal_conv: - c = c[:, :, self.upsample_scales[i] : -self.upsample_scales[i]] cs = 0.0 # initialize for j in range(self.num_blocks): cs += self.blocks[i * self.num_blocks + j](c) From 76749ed878832a2398851a4541a53cadc85d7661 Mon Sep 17 00:00:00 2001 From: chomeyama Date: Sat, 19 Feb 2022 03:05:15 +0900 Subject: [PATCH 5/5] fixed imports in alphabetical order --- parallel_wavegan/models/hifigan.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/parallel_wavegan/models/hifigan.py b/parallel_wavegan/models/hifigan.py index f039b9ca..d6b6e62e 100644 --- a/parallel_wavegan/models/hifigan.py +++ b/parallel_wavegan/models/hifigan.py @@ -13,9 +13,9 @@ import torch import torch.nn.functional as F -from parallel_wavegan.layers import HiFiGANResidualBlock as ResidualBlock from parallel_wavegan.layers import CausalConv1d from parallel_wavegan.layers import CausalConvTranspose1d +from parallel_wavegan.layers import HiFiGANResidualBlock as ResidualBlock from parallel_wavegan.utils import read_hdf5
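The length bookkeeping behind CausalConvTranspose1d in PATCH 3 can be checked with plain torch modules: replicate one frame on the left, upsample with kernel_size == 2 * stride (the constraint the generator asserts for every upsample layer), and then drop stride samples from each end so that exactly stride * T samples remain. A minimal sketch with illustrative channel counts and lengths:

import torch

stride = 8
kernel_size = 2 * stride  # mirrors upsample_kernel_sizes[i] == 2 * upsample_scales[i]
pad = torch.nn.ReplicationPad1d((1, 0))
deconv = torch.nn.ConvTranspose1d(4, 4, kernel_size, stride)

x = torch.randn(1, 4, 25)
y = deconv(pad(x))  # length: (T + 1 - 1) * stride + kernel_size = T * stride + 2 * stride
y = y[:, :, stride:-stride]  # same trim as PATCH 3's forward()
assert y.size(-1) == x.size(-1) * stride

PATCH 2's test_causal_hifigan then checks the end-to-end property: with use_causal_conv=True, perturbing only the second half of the conditioning leaves the first half of the generated waveform unchanged.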