Merge pull request #716 from AznamirWoW/fix_for_nonF0_generator
corrected mismatch between generator and pretrained nonF0 weights
blaisewf authored Sep 19, 2024
2 parents 1ceb575 + 87da6ba commit 0cb623a
Showing 1 changed file with 29 additions and 27 deletions.
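
As context for the commit description above, here is a minimal standalone sketch (hypothetical toy modules, not the actual RVC classes) of why a combined container breaks checkpoint loading: the name of the ModuleList attribute becomes part of every parameter key in state_dict(), so weights saved from a model with separate "ups"/"resblocks" containers no longer line up with a model that registers everything under "ups_and_resblocks".

import torch

# Toy layout matching the pretrained checkpoints: separate containers.
class SplitLayout(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.ups = torch.nn.ModuleList([torch.nn.ConvTranspose1d(4, 2, 3)])
        self.resblocks = torch.nn.ModuleList([torch.nn.Conv1d(2, 2, 3)])

# Toy layout with a single combined container.
class CombinedLayout(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.ups_and_resblocks = torch.nn.ModuleList(
            [torch.nn.ConvTranspose1d(4, 2, 3), torch.nn.Conv1d(2, 2, 3)]
        )

print(list(SplitLayout().state_dict()))     # keys like 'ups.0.weight', 'resblocks.0.weight', ...
print(list(CombinedLayout().state_dict()))  # keys like 'ups_and_resblocks.0.weight', ...

# Loading a checkpoint saved with the split layout into the combined layout
# reports every tensor as missing or unexpected, i.e. nothing is actually restored.
result = CombinedLayout().load_state_dict(SplitLayout().state_dict(), strict=False)
print(result.missing_keys, result.unexpected_keys)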
56 changes: 29 additions & 27 deletions rvc/lib/algorithm/generators.py
@@ -35,14 +35,14 @@ def __init__(
         super(Generator, self).__init__()
         self.num_kernels = len(resblock_kernel_sizes)
         self.num_upsamples = len(upsample_rates)
-        self.conv_pre = torch.nn.Conv1d(
-            initial_channel, upsample_initial_channel, 7, 1, padding=3
-        )
+        self.conv_pre = torch.nn.Conv1d(initial_channel, upsample_initial_channel, 7, 1, padding=3)
         resblock = ResBlock1 if resblock == "1" else ResBlock2

-        self.ups_and_resblocks = torch.nn.ModuleList()
+        self.ups = torch.nn.ModuleList()
+        self.resblocks = torch.nn.ModuleList()

         for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
-            self.ups_and_resblocks.append(
+            self.ups.append(
                 weight_norm(
                     torch.nn.ConvTranspose1d(
                         upsample_initial_channel // (2**i),
@@ -57,35 +57,35 @@ def __init__(
             for j, (k, d) in enumerate(
                 zip(resblock_kernel_sizes, resblock_dilation_sizes)
             ):
-                self.ups_and_resblocks.append(resblock(ch, k, d))
+                self.resblocks.append(resblock(ch, k, d))

         self.conv_post = torch.nn.Conv1d(ch, 1, 7, 1, padding=3, bias=False)
-        self.ups_and_resblocks.apply(init_weights)
+        self.ups.apply(init_weights)

         if gin_channels != 0:
             self.cond = torch.nn.Conv1d(gin_channels, upsample_initial_channel, 1)

-    def forward(self, x: torch.Tensor, g: Optional[torch.Tensor] = None):
-        x = self.conv_pre(x)
-        if g is not None:
-            x = x + self.cond(g)
+    def forward(self, x: torch.Tensor, g: Optional[torch.Tensor] = None):
+        x = self.conv_pre(x)
+        if g is not None:
+            x = x + self.cond(g)

-        resblock_idx = 0
-        for _ in range(self.num_upsamples):
-            x = torch.nn.functional.leaky_relu(x, LRELU_SLOPE)
-            x = self.ups_and_resblocks[resblock_idx](x)
-            resblock_idx += 1
-            xs = 0
-            for _ in range(self.num_kernels):
-                xs += self.ups_and_resblocks[resblock_idx](x)
-                resblock_idx += 1
-            x = xs / self.num_kernels
+        for i in range(self.num_upsamples):
+            x = torch.nn.functional.leaky_relu(x, LRELU_SLOPE)
+            x = self.ups[i](x)
+            xs = None
+            for j in range(self.num_kernels):
+                if xs == None:
+                    xs = self.resblocks[i * self.num_kernels + j](x)
+                else:
+                    xs += self.resblocks[i * self.num_kernels + j](x)
+            x = xs / self.num_kernels

-        x = torch.nn.functional.leaky_relu(x)
-        x = self.conv_post(x)
-        x = torch.tanh(x)
+        x = torch.nn.functional.leaky_relu(x)
+        x = self.conv_post(x)
+        x = torch.tanh(x)

-        return x
+        return x

     def __prepare_scriptable__(self):
         """Prepares the module for scripting."""
@@ -100,8 +100,10 @@ def __prepare_scriptable__(self):

     def remove_weight_norm(self):
         """Removes weight normalization from the upsampling and residual blocks."""
-        for l in self.ups_and_resblocks:
+        for l in self.ups:
             remove_weight_norm(l)
+        for l in self.resblocks:
+            l.remove_weight_norm()


 class SineGen(torch.nn.Module):
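
A short note on the rewritten forward() above, with a sketch under assumed toy values (not code from the repository): because the residual blocks are appended in a nested loop over upsample stages and kernel sizes, the block for stage i and kernel j ends up at flat position i * num_kernels + j in self.resblocks, which is exactly the index the new loop uses in place of the old running resblock_idx counter.

# Sketch: the flat index i * num_kernels + j visits the blocks in append order.
num_upsamples, num_kernels = 4, 3  # assumed toy sizes

append_order = [(i, j) for i in range(num_upsamples) for j in range(num_kernels)]
for i in range(num_upsamples):
    for j in range(num_kernels):
        assert append_order[i * num_kernels + j] == (i, j)
print("flat indexing matches the nested append order")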
