Skip to content

Commit

Permalink
Merge pull request #676 from ShiromiyaG/new-pr
Browse files Browse the repository at this point in the history
Fix overtrain
  • Loading branch information
blaisewf authored Sep 6, 2024
2 parents 66c4ad9 + 55aefb9 commit aced175
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions rvc/train/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -563,7 +563,7 @@ def train_and_evaluate(
writers (list): List of TensorBoard writers [writer, writer_eval].
cache (list): List to cache data in GPU memory.
"""
global global_step, lowest_value, loss_disc, consecutive_increases_gen, consecutive_increases_disc
global global_step, lowest_value, loss_disc, consecutive_increases_gen, consecutive_increases_disc, smoothed_value_gen, smoothed_value_disc

if epoch == 1:
lowest_value = {"step": 0, "value": float("inf"), "epoch": 0}
Expand Down Expand Up @@ -858,8 +858,10 @@ def train_and_evaluate(
ckpt = net_g.module.state_dict()
else:
ckpt = net_g.state_dict()
if overtraining_detector != True:
overtrain_info = None
if overtraining_detector and epoch > 1:
overtrain_info = f"Smoothed loss_g {smoothed_value_gen:.3f} and loss_d {smoothed_value_disc:.3f}"
else:
overtrain_info = ""
extract_model(
ckpt=ckpt,
sr=sample_rate,
Expand Down

0 comments on commit aced175

Please sign in to comment.