Skip to content

Commit

Permalink
Merge pull request #895 from AznamirWoW/switch_to_signal_instead_of_m…
Browse files Browse the repository at this point in the history
…el_spec

replaced the mel loss function with a better alternative; general fixes for the training loop
  • Loading branch information
blaisewf authored Dec 2, 2024
2 parents d3a29a1 + c0763f8 commit 3dbcf51
Show file tree
Hide file tree
Showing 2 changed files with 134 additions and 63 deletions.
70 changes: 69 additions & 1 deletion rvc/train/mel_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def spectrogram_torch(y, n_fft, hop_size, win_size, center=False):

spec = torch.stft(
y,
n_fft,
n_fft=n_fft,
hop_length=hop_size,
win_length=win_size,
window=hann_window[wnsize_dtype_device],
Expand Down Expand Up @@ -144,3 +144,71 @@ def mel_spectrogram_torch(
melspec = spec_to_mel_torch(spec, n_fft, num_mels, sample_rate, fmin, fmax)

return melspec

def compute_window_length(n_mels: int, sample_rate: int):
    """Derive an STFT window length (in samples) for a given mel resolution.

    The window is sized so that roughly 8 samples of bandwidth are available
    per mel band over the full 0..Nyquist range, then rounded DOWN to the
    nearest power of two (a power-of-two FFT size).
    """
    low_hz = 0
    nyquist_hz = sample_rate / 2
    seconds = 8 * n_mels / (nyquist_hz - low_hz)
    raw_samples = int(seconds * sample_rate)
    # Largest power of two not exceeding raw_samples.
    return 2 ** (raw_samples.bit_length() - 1)

class MultiScaleMelSpectrogramLoss(torch.nn.Module):
    """Multi-resolution log-mel spectrogram distance between two waveforms.

    For each mel-band count in ``n_mels``, an STFT window length is derived
    with ``compute_window_length`` and a mel filterbank is precomputed. The
    loss is the sum over resolutions of ``loss_fn`` applied to the log10
    mel spectrograms of the reference and generated signals.
    """

    def __init__(
        self,
        sample_rate: int = 24000,
        # Immutable default avoids the shared-mutable-default-argument pitfall.
        n_mels=(5, 10, 20, 40, 80, 160, 320, 480),
        loss_fn=None,
    ):
        """
        Args:
            sample_rate: audio sample rate in Hz.
            n_mels: iterable of mel-band counts, one per resolution.
            loss_fn: elementwise distance applied to log-mel tensors.
                Defaults to a fresh ``torch.nn.L1Loss()`` per instance
                (instantiating it in the signature would share one module
                across every instance of this class).
        """
        super().__init__()
        self.sample_rate = sample_rate
        self.loss_fn = torch.nn.L1Loss() if loss_fn is None else loss_fn
        self.log_base = torch.log(torch.tensor(10.0))  # ln(10): converts ln -> log10
        self.stft_params = {}
        self.mel_banks = {}

        window_lengths = [compute_window_length(m, sample_rate) for m in n_mels]

        # NOTE: distinct loop name (mel_count) — the original shadowed the
        # n_mels parameter here.
        for mel_count, window_length in zip(n_mels, window_lengths):
            self.stft_params[mel_count] = {
                "n_mels": mel_count,
                "window_length": window_length,
                # 10 ms hop at the configured sample rate.
                "hop_length": self.sample_rate // 100,
            }
            self.mel_banks[mel_count] = torch.from_numpy(
                librosa_mel_fn(
                    sr=self.sample_rate,
                    n_mels=mel_count,
                    n_fft=window_length,
                    fmin=0,
                    fmax=None,
                )
            )

    def mel_spectrogram(self, wav, n_mels, window_length, hop_length):
        """Magnitude mel spectrogram of ``wav`` at one resolution.

        Args:
            wav: (B, 1, T) waveform tensor.
            n_mels: mel-band count; keys into the precomputed filterbanks.
            window_length: STFT window/FFT size in samples.
            hop_length: STFT hop in samples.

        Returns:
            (B, n_mels, frames) magnitude mel spectrogram.
        """
        wav = wav.squeeze(1)  # (B, 1, T) -> (B, T)
        window = torch.hann_window(window_length).to(wav.device).to(wav.dtype)
        stft = torch.stft(
            wav.float(),
            n_fft=window_length,
            hop_length=hop_length,
            window=window,
            return_complex=True,
        )  # (B, window_length // 2 + 1, frames)
        # Epsilon inside the sqrt keeps the gradient finite at zero magnitude.
        magnitude = torch.sqrt(stft.real.pow(2) + stft.imag.pow(2) + 1e-6)
        mel_basis = self.mel_banks[n_mels].to(wav.device)  # (n_mels, bins)
        return torch.matmul(mel_basis, magnitude)

    def forward(self, real, fake):
        """Sum of per-resolution log10-mel distances.

        Args:
            real: (B, 1, T) reference waveform.
            fake: (B, 1, T) generated waveform.

        Returns:
            Scalar loss tensor.
        """
        loss = 0.0
        for params in self.stft_params.values():
            real_mels = self.mel_spectrogram(real, **params)
            fake_mels = self.mel_spectrogram(fake, **params)
            # Clamp before the log to avoid -inf; divide by ln(10) for log10.
            # (The original applied a no-op .pow(1) here; removed.)
            real_logmels = torch.log(real_mels.clamp(min=1e-5)) / self.log_base
            fake_logmels = torch.log(fake_mels.clamp(min=1e-5)) / self.log_base
            loss += self.loss_fn(real_logmels, fake_logmels)
        return loss

127 changes: 65 additions & 62 deletions rvc/train/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
generator_loss,
kl_loss,
)
from mel_processing import mel_spectrogram_torch, spec_to_mel_torch
from mel_processing import mel_spectrogram_torch, spec_to_mel_torch, MultiScaleMelSpectrogramLoss

from rvc.train.process.extract_model import extract_model

Expand Down Expand Up @@ -333,7 +333,7 @@ def run(
train_sampler = DistributedBucketSampler(
train_dataset,
batch_size * n_gpus,
[100, 200, 300, 400, 500, 600, 700, 800, 900],
[50, 100, 200, 300, 400, 500, 600, 700, 800, 900],
num_replicas=n_gpus,
rank=rank,
shuffle=True,
Expand Down Expand Up @@ -377,6 +377,8 @@ def run(
betas=config.train.betas,
eps=config.train.eps,
)

fn_mel_loss = MultiScaleMelSpectrogramLoss(sample_rate=sample_rate)

# Wrap models with DDP for multi-gpu processing
if n_gpus > 1 and device.type == "cuda":
Expand All @@ -398,7 +400,7 @@ def run(
except:
epoch_str = 1
global_step = 0
if pretrainG != "":
if pretrainG != "" and pretrainG != "None":
if rank == 0:
verify_checkpoint_shapes(pretrainG, net_g)
print(f"Loaded pretrained (G) '{pretrainG}'")
Expand All @@ -411,7 +413,7 @@ def run(
torch.load(pretrainG, map_location="cpu")["model"]
)

if pretrainD != "":
if pretrainD != "" and pretrainD != "None":
if rank == 0:
print(f"Loaded pretrained (D) '{pretrainD}'")
if hasattr(net_d, "module"):
Expand Down Expand Up @@ -489,6 +491,7 @@ def run(
custom_total_epoch,
device,
reference,
fn_mel_loss
)

scheduler_g.step()
Expand All @@ -509,6 +512,7 @@ def train_and_evaluate(
custom_total_epoch,
device,
reference,
fn_mel_loss
):
"""
Trains and evaluates the model for one epoch.
Expand Down Expand Up @@ -589,36 +593,6 @@ def train_and_evaluate(
y_hat, ids_slice, x_mask, z_mask, (z, z_p, m_p, logs_p, m_q, logs_q) = (
model_output
)
# used for tensorboard chart - all/mel
mel = spec_to_mel_torch(
spec,
config.data.filter_length,
config.data.n_mel_channels,
config.data.sample_rate,
config.data.mel_fmin,
config.data.mel_fmax,
)
# used for tensorboard chart - slice/mel_org
y_mel = commons.slice_segments(
mel,
ids_slice,
config.train.segment_size // config.data.hop_length,
dim=3,
)
# used for tensorboard chart - slice/mel_gen
with autocast(enabled=False):
y_hat_mel = mel_spectrogram_torch(
y_hat.float().squeeze(1),
config.data.filter_length,
config.data.n_mel_channels,
config.data.sample_rate,
config.data.hop_length,
config.data.win_length,
config.data.mel_fmin,
config.data.mel_fmax,
)
if use_amp:
y_hat_mel = y_hat_mel.half()
# slice of the original waveform to match a generate slice
wave = commons.slice_segments(
wave,
Expand All @@ -640,25 +614,16 @@ def train_and_evaluate(

# Generator backward and update
with autocast(enabled=use_amp):
y_d_hat_r, y_d_hat_g, fmap_r, fmap_g = net_d(wave, y_hat)
_, y_d_hat_g, fmap_r, fmap_g = net_d(wave, y_hat)
with autocast(enabled=False):
loss_mel = F.l1_loss(y_mel, y_hat_mel) * config.train.c_mel
loss_kl = (
kl_loss(z_p, logs_q, m_p, logs_p, z_mask) * config.train.c_kl
)
loss_mel = fn_mel_loss(wave, y_hat) * config.train.c_mel / 3.0
loss_kl = kl_loss(z_p, logs_q, m_p, logs_p, z_mask) * config.train.c_kl
loss_fm = feature_loss(fmap_r, fmap_g)
loss_gen, losses_gen = generator_loss(y_d_hat_g)
loss_gen_all = loss_gen + loss_fm + loss_mel + loss_kl

if loss_gen_all < lowest_value["value"]:
lowest_value["value"] = loss_gen_all
lowest_value["step"] = global_step
lowest_value["epoch"] = epoch
# print(f'Lowest generator loss updated: {lowest_value["value"]} at epoch {epoch}, step {global_step}')
if epoch > lowest_value["epoch"]:
print(
"Alert: The lower generating loss has been exceeded by a lower loss in a subsequent epoch."
)
lowest_value = {"step": global_step, "value": loss_gen_all, "epoch": epoch}

optim_g.zero_grad()
scaler.scale(loss_gen_all).backward()
Expand All @@ -672,6 +637,37 @@ def train_and_evaluate(

# Logging and checkpointing
if rank == 0:
# used for tensorboard chart - all/mel
mel = spec_to_mel_torch(
spec,
config.data.filter_length,
config.data.n_mel_channels,
config.data.sample_rate,
config.data.mel_fmin,
config.data.mel_fmax,
)
# used for tensorboard chart - slice/mel_org
y_mel = commons.slice_segments(
mel,
ids_slice,
config.train.segment_size // config.data.hop_length,
dim=3,
)
# used for tensorboard chart - slice/mel_gen
with autocast(enabled=False):
y_hat_mel = mel_spectrogram_torch(
y_hat.float().squeeze(1),
config.data.filter_length,
config.data.n_mel_channels,
config.data.sample_rate,
config.data.hop_length,
config.data.win_length,
config.data.mel_fmin,
config.data.mel_fmax,
)
if use_amp:
y_hat_mel = y_hat_mel.half()

lr = optim_g.param_groups[0]["lr"]
if loss_mel > 75:
loss_mel = 75
Expand All @@ -698,21 +694,28 @@ def train_and_evaluate(
"all/mel": plot_spectrogram_to_numpy(mel[0].data.cpu().numpy()),
}

with torch.no_grad():
if hasattr(net_g, "module"):
o, *_ = net_g.module.infer(*reference)
else:
o, *_ = net_g.infer(*reference)
audio_dict = {f"gen/audio_{global_step:07d}": o[0, :, :]}

summarize(
writer=writer,
global_step=global_step,
images=image_dict,
scalars=scalar_dict,
audios=audio_dict,
audio_sample_rate=config.data.sample_rate,
)
if epoch % save_every_epoch == 0:
with torch.no_grad():
if hasattr(net_g, "module"):
o, *_ = net_g.module.infer(*reference)
else:
o, *_ = net_g.infer(*reference)
audio_dict = {f"gen/audio_{global_step:07d}": o[0, :, :]}
summarize(
writer=writer,
global_step=global_step,
images=image_dict,
scalars=scalar_dict,
audios=audio_dict,
audio_sample_rate=config.data.sample_rate,
)
else:
summarize(
writer=writer,
global_step=global_step,
images=image_dict,
scalars=scalar_dict,
)

# Save checkpoint
model_add = []
Expand Down

0 comments on commit 3dbcf51

Please sign in to comment.