diff --git a/rvc/lib/algorithm/generators.py b/rvc/lib/algorithm/generators.py
index 75fc2ad1..3b4fc75c 100644
--- a/rvc/lib/algorithm/generators.py
+++ b/rvc/lib/algorithm/generators.py
@@ -35,14 +35,14 @@ def __init__(
         super(Generator, self).__init__()
         self.num_kernels = len(resblock_kernel_sizes)
         self.num_upsamples = len(upsample_rates)
-        self.conv_pre = torch.nn.Conv1d(
-            initial_channel, upsample_initial_channel, 7, 1, padding=3
-        )
+        self.conv_pre = torch.nn.Conv1d(initial_channel, upsample_initial_channel, 7, 1, padding=3)
         resblock = ResBlock1 if resblock == "1" else ResBlock2
-
-        self.ups_and_resblocks = torch.nn.ModuleList()
+
+        self.ups = torch.nn.ModuleList()
+        self.resblocks = torch.nn.ModuleList()
+
         for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
-            self.ups_and_resblocks.append(
+            self.ups.append(
                 weight_norm(
                     torch.nn.ConvTranspose1d(
                         upsample_initial_channel // (2**i),
@@ -57,35 +57,35 @@ def __init__(
             for j, (k, d) in enumerate(
                 zip(resblock_kernel_sizes, resblock_dilation_sizes)
             ):
-                self.ups_and_resblocks.append(resblock(ch, k, d))
+                self.resblocks.append(resblock(ch, k, d))
 
         self.conv_post = torch.nn.Conv1d(ch, 1, 7, 1, padding=3, bias=False)
-        self.ups_and_resblocks.apply(init_weights)
+        self.ups.apply(init_weights)
 
         if gin_channels != 0:
             self.cond = torch.nn.Conv1d(gin_channels, upsample_initial_channel, 1)
 
-    def forward(self, x: torch.Tensor, g: Optional[torch.Tensor] = None):
-        x = self.conv_pre(x)
-        if g is not None:
-            x = x + self.cond(g)
+    def forward(self, x: torch.Tensor, g: Optional[torch.Tensor] = None):
+        x = self.conv_pre(x)
+        if g is not None:
+            x = x + self.cond(g)
 
-        resblock_idx = 0
-        for _ in range(self.num_upsamples):
-            x = torch.nn.functional.leaky_relu(x, LRELU_SLOPE)
-            x = self.ups_and_resblocks[resblock_idx](x)
-            resblock_idx += 1
-            xs = 0
-            for _ in range(self.num_kernels):
-                xs += self.ups_and_resblocks[resblock_idx](x)
-                resblock_idx += 1
-            x = xs / self.num_kernels
+        for i in range(self.num_upsamples):
+            x = torch.nn.functional.leaky_relu(x, LRELU_SLOPE)
+            x = self.ups[i](x)
+            xs = None
+            for j in range(self.num_kernels):
+                if xs is None:
+                    xs = self.resblocks[i * self.num_kernels + j](x)
+                else:
+                    xs += self.resblocks[i * self.num_kernels + j](x)
+            x = xs / self.num_kernels
 
-        x = torch.nn.functional.leaky_relu(x)
-        x = self.conv_post(x)
-        x = torch.tanh(x)
+        x = torch.nn.functional.leaky_relu(x)
+        x = self.conv_post(x)
+        x = torch.tanh(x)
 
-        return x
+        return x
 
     def __prepare_scriptable__(self):
         """Prepares the module for scripting."""
@@ -100,8 +100,10 @@ def __prepare_scriptable__(self):
 
     def remove_weight_norm(self):
         """Removes weight normalization from the upsampling and residual blocks."""
-        for l in self.ups_and_resblocks:
+        for l in self.ups:
             remove_weight_norm(l)
+        for l in self.resblocks:
+            l.remove_weight_norm()
 
 
 class SineGen(torch.nn.Module):