Skip to content

Commit

Permalink
fix mistake in multispeaker training
Browse files (browse the repository at this point in the history)
  • Loading branch information
yl4579 authored Nov 3, 2023
1 parent 11cb0cb commit 041ea6d
Showing 1 changed file with 9 additions and 3 deletions.
12 changes: 9 additions & 3 deletions train_second.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -345,6 +345,7 @@ def main(config_path):
mel_len = min(int(mel_input_length.min().item() / 2 - 1), max_len // 2)
en = []
gt = []
st = []
p_en = []
wav = []

Expand All @@ -358,18 +359,23 @@ def main(config_path):

y = waves[bib][(random_start * 2) * 300:((random_start+mel_len) * 2) * 300]
wav.append(torch.from_numpy(y).to(device))

# style reference segment (preferably different from the ground-truth segment)
random_start = np.random.randint(0, mel_length - mel_len_st)
st.append(mels[bib, :, (random_start * 2):((random_start+mel_len_st) * 2)])

wav = torch.stack(wav).float().detach()

en = torch.stack(en)
p_en = torch.stack(p_en)
gt = torch.stack(gt).detach()

st = torch.stack(st).detach()

if gt.size(-1) < 80:
continue

s_dur = model.predictor_encoder(gt.unsqueeze(1))
s = model.style_encoder(gt.unsqueeze(1))
s_dur = model.predictor_encoder(st.unsqueeze(1) if multispeaker else gt.unsqueeze(1))
s = model.style_encoder(st.unsqueeze(1) if multispeaker else gt.unsqueeze(1))

with torch.no_grad():
F0_real, _, F0 = model.pitch_extractor(gt.unsqueeze(1))
Expand Down

0 comments on commit 041ea6d

Please sign in to comment.