diff --git a/rvc/infer/pipeline.py b/rvc/infer/pipeline.py
index 8a988340..7c8f58c1 100644
--- a/rvc/infer/pipeline.py
+++ b/rvc/infer/pipeline.py
@@ -417,41 +417,62 @@ def voice_conversion(
         with torch.no_grad():
             pitch_guidance = pitch != None and pitchf != None
             # prepare source audio
-            feats = torch.from_numpy(audio0).half() if self.is_half else torch.from_numpy(audio0).float()
+            feats = (
+                torch.from_numpy(audio0).half()
+                if self.is_half
+                else torch.from_numpy(audio0).float()
+            )
             feats = feats.mean(-1) if feats.dim() == 2 else feats
             assert feats.dim() == 1, feats.dim()
             feats = feats.view(1, -1).to(self.device)
             # extract features
             feats = model(feats)["last_hidden_state"]
-            feats = model.final_proj(feats[0]).unsqueeze(0) if version == "v1" else feats
+            feats = (
+                model.final_proj(feats[0]).unsqueeze(0) if version == "v1" else feats
+            )
             # make a copy for pitch guidance and protection
             feats0 = feats.clone() if pitch_guidance else None
-            if index:  # set by parent function, only true if index is available, loaded, and index rate > 0
-                feats = self._retrieve_speaker_embeddings(feats, index, big_npy, index_rate)
-            # feature upsampling
-            feats = F.interpolate(feats.permute(0, 2, 1), scale_factor=2).permute(0, 2, 1)
+            if (
+                index
+            ):  # set by parent function, only true if index is available, loaded, and index rate > 0
+                feats = self._retrieve_speaker_embeddings(
+                    feats, index, big_npy, index_rate
+                )
+            # feature upsampling
+            feats = F.interpolate(feats.permute(0, 2, 1), scale_factor=2).permute(
+                0, 2, 1
+            )
             # adjust the length if the audio is short
             p_len = min(audio0.shape[0] // self.window, feats.shape[1])
             if pitch_guidance:
-                feats0 = F.interpolate(feats0.permute(0, 2, 1), scale_factor=2).permute(0, 2, 1)
+                feats0 = F.interpolate(feats0.permute(0, 2, 1), scale_factor=2).permute(
+                    0, 2, 1
+                )
                 pitch, pitchf = pitch[:, :p_len], pitchf[:, :p_len]
                 # Pitch protection blending
                 if protect < 0.5:
                     pitchff = pitchf.clone()
                     pitchff[pitchf > 0] = 1
                     pitchff[pitchf < 1] = protect
-                    feats = feats * pitchff.unsqueeze(-1) + feats0 * (1 - pitchff.unsqueeze(-1))
+                    feats = feats * pitchff.unsqueeze(-1) + feats0 * (
+                        1 - pitchff.unsqueeze(-1)
+                    )
                     feats = feats.to(feats0.dtype)
             else:
                 pitch, pitchf = None, None
             p_len = torch.tensor([p_len], device=self.device).long()
-            audio1 = ((net_g.infer(feats, p_len, pitch, pitchf, sid)[0][0, 0]).data.cpu().float().numpy())
+            audio1 = (
+                (net_g.infer(feats, p_len, pitch, pitchf, sid)[0][0, 0])
+                .data.cpu()
+                .float()
+                .numpy()
+            )
             # clean up
             del feats, feats0, p_len
             if torch.cuda.is_available():
                 torch.cuda.empty_cache()
         return audio1
-
+
     def _retrieve_speaker_embeddings(self, feats, index, big_npy, index_rate):
         npy = feats[0].cpu().numpy()
         npy = npy.astype("float32") if self.is_half else npy
@@ -460,9 +481,12 @@ def _retrieve_speaker_embeddings(self, feats, index, big_npy, index_rate):
         weight /= weight.sum(axis=1, keepdims=True)
         npy = np.sum(big_npy[ix] * np.expand_dims(weight, axis=2), axis=1)
         npy = npy.astype("float16") if self.is_half else npy
-        feats = torch.from_numpy(npy).unsqueeze(0).to(self.device) * index_rate + (1 - index_rate) * feats
+        feats = (
+            torch.from_numpy(npy).unsqueeze(0).to(self.device) * index_rate
+            + (1 - index_rate) * feats
+        )
         return feats
-
+
     def pipeline(
         self,
         model,
diff --git a/rvc/train/data_utils.py b/rvc/train/data_utils.py
index a4e7ef0f..b6c3d644 100644
--- a/rvc/train/data_utils.py
+++ b/rvc/train/data_utils.py
@@ -6,6 +6,7 @@
 from mel_processing import spectrogram_torch
 from utils import load_filepaths_and_text, load_wav_to_torch
 
+
 class TextAudioLoaderMultiNSFsid(torch.utils.data.Dataset):
     """
     Dataset that loads text and audio pairs.
@@ -163,6 +164,7 @@ def __len__(self):
         """
         return len(self.audiopaths_and_text)
 
+
 class TextAudioCollateMultiNSFsid:
     """
     Collates text and audio data for training.
diff --git a/rvc/train/train.py b/rvc/train/train.py
index 4054cad2..4f4ce221 100644
--- a/rvc/train/train.py
+++ b/rvc/train/train.py
@@ -148,11 +148,15 @@ def main():
     os.environ["MASTER_ADDR"] = "localhost"
     os.environ["MASTER_PORT"] = str(randint(20000, 55555))
     # Check sample rate
-    wavs = glob.glob(os.path.join(os.path.join(experiment_dir, "sliced_audios"), "*.wav"))
+    wavs = glob.glob(
+        os.path.join(os.path.join(experiment_dir, "sliced_audios"), "*.wav")
+    )
     if wavs:
         _, sr = load_wav_to_torch(wavs[0])
         if sr != sample_rate:
-            print(f"Error: Pretrained model sample rate ({sample_rate} Hz) does not match dataset audio sample rate ({sr} Hz).")
+            print(
+                f"Error: Pretrained model sample rate ({sample_rate} Hz) does not match dataset audio sample rate ({sr} Hz)."
+            )
             os._exit(1)
     else:
         print("No wav file found.")
@@ -240,10 +244,12 @@ def continue_overtrain_detector(training_file_path):
             ) = load_from_json(training_file_path)
 
     if sync_graph:
-        print("Sync graph is now activated! With sync graph enabled, the model undergoes a single epoch of training. Once the graphs are synchronized, training proceeds for the previously specified number of epochs.")
+        print(
+            "Sync graph is now activated! With sync graph enabled, the model undergoes a single epoch of training. Once the graphs are synchronized, training proceeds for the previously specified number of epochs."
+        )
         custom_total_epoch = 1
         custom_save_every_weights = True
-
+
         start()
 
         # Synchronize graphs by modifying config files
@@ -288,14 +294,18 @@ def edit_config(config_file):
         edit_config(rvc_config_file)
 
         # Clean up unnecessary files
-        for root, dirs, files in os.walk(os.path.join(now_dir, "logs", model_name), topdown=False):
+        for root, dirs, files in os.walk(
+            os.path.join(now_dir, "logs", model_name), topdown=False
+        ):
             for name in files:
                 file_path = os.path.join(root, name)
                 file_name, file_extension = os.path.splitext(name)
-                if (file_extension == ".0" or
-                    (file_name.startswith("D_") and file_extension == ".pth") or
-                    (file_name.startswith("G_") and file_extension == ".pth") or
-                    (file_name.startswith("added") and file_extension == ".index")):
+                if (
+                    file_extension == ".0"
+                    or (file_name.startswith("D_") and file_extension == ".pth")
+                    or (file_name.startswith("G_") and file_extension == ".pth")
+                    or (file_name.startswith("added") and file_extension == ".index")
+                ):
                     os.remove(file_path)
             for name in dirs:
                 if name == "eval":
@@ -317,6 +327,7 @@ def edit_config(config_file):
     continue_overtrain_detector(training_file_path)
     start()
 
+
 def run(
     rank,
     n_gpus,
@@ -409,8 +420,18 @@ def run(
     else:
         net_d = MultiPeriodDiscriminatorV2(config.model.use_spectral_norm).to(device)
 
-    optim_g = torch.optim.AdamW(net_g.parameters(), config.train.learning_rate, betas=config.train.betas, eps=config.train.eps)
-    optim_d = torch.optim.AdamW(net_d.parameters(), config.train.learning_rate, betas=config.train.betas, eps=config.train.eps)
+    optim_g = torch.optim.AdamW(
+        net_g.parameters(),
+        config.train.learning_rate,
+        betas=config.train.betas,
+        eps=config.train.eps,
+    )
+    optim_d = torch.optim.AdamW(
+        net_d.parameters(),
+        config.train.learning_rate,
+        betas=config.train.betas,
+        eps=config.train.eps,
+    )
 
     # Wrap models with DDP for multi-gpu processing
     if n_gpus > 1 and device.type == "cuda":
@@ -420,8 +441,12 @@ def run(
     # Load checkpoint if available
     try:
         print("Starting training...")
-        _, _, _, epoch_str = load_checkpoint(latest_checkpoint_path(experiment_dir, "D_*.pth"), net_d, optim_d)
-        _, _, _, epoch_str = load_checkpoint(latest_checkpoint_path(experiment_dir, "G_*.pth"), net_g, optim_g)
+        _, _, _, epoch_str = load_checkpoint(
+            latest_checkpoint_path(experiment_dir, "D_*.pth"), net_d, optim_d
+        )
+        _, _, _, epoch_str = load_checkpoint(
+            latest_checkpoint_path(experiment_dir, "G_*.pth"), net_g, optim_g
+        )
         epoch_str += 1
 
         global_step = (epoch_str - 1) * len(train_loader)
@@ -433,21 +458,33 @@ def run(
                 verify_checkpoint_shapes(pretrainG, net_g)
                 print(f"Loaded pretrained (G) '{pretrainG}'")
             if hasattr(net_g, "module"):
-                net_g.module.load_state_dict(torch.load(pretrainG, map_location="cpu")["model"])
+                net_g.module.load_state_dict(
+                    torch.load(pretrainG, map_location="cpu")["model"]
+                )
             else:
-                net_g.load_state_dict(torch.load(pretrainG, map_location="cpu")["model"])
+                net_g.load_state_dict(
+                    torch.load(pretrainG, map_location="cpu")["model"]
+                )
 
         if pretrainD != "":
             if rank == 0:
                 print(f"Loaded pretrained (D) '{pretrainD}'")
             if hasattr(net_d, "module"):
-                net_d.module.load_state_dict(torch.load(pretrainD, map_location="cpu")["model"])
+                net_d.module.load_state_dict(
+                    torch.load(pretrainD, map_location="cpu")["model"]
+                )
             else:
-                net_d.load_state_dict(torch.load(pretrainD, map_location="cpu")["model"])
+                net_d.load_state_dict(
+                    torch.load(pretrainD, map_location="cpu")["model"]
+                )
 
     # Initialize schedulers and scaler
-    scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=config.train.lr_decay, last_epoch=epoch_str - 2)
-    scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=config.train.lr_decay, last_epoch=epoch_str - 2)
+    scheduler_g = torch.optim.lr_scheduler.ExponentialLR(
+        optim_g, gamma=config.train.lr_decay, last_epoch=epoch_str - 2
+    )
+    scheduler_d = torch.optim.lr_scheduler.ExponentialLR(
+        optim_d, gamma=config.train.lr_decay, last_epoch=epoch_str - 2
+    )
 
     optim_g.step()
     optim_d.step()
@@ -489,6 +526,7 @@ def run(
         scheduler_g.step()
         scheduler_d.step()
 
+
 def train_and_evaluate(
     rank,
     epoch,
@@ -542,21 +580,39 @@ def train_and_evaluate(
         data_iterator = cache
         if cache == []:
             for batch_idx, info in enumerate(train_loader):
-                phone, phone_lengths, pitch, pitchf, spec, spec_lengths, wave, wave_lengths, sid = info
+                (
+                    phone,
+                    phone_lengths,
+                    pitch,
+                    pitchf,
+                    spec,
+                    spec_lengths,
+                    wave,
+                    wave_lengths,
+                    sid,
+                ) = info
                 cache.append(
                     (
                         batch_idx,
                         (
                             phone.cuda(rank, non_blocking=True),
-                            phone_lengths.cuda(rank, non_blocking=True),
-                            pitch.cuda(rank, non_blocking=True) if pitch_guidance else None,
-                            pitchf.cuda(rank, non_blocking=True) if pitch_guidance else None,
+                            phone_lengths.cuda(rank, non_blocking=True),
+                            (
+                                pitch.cuda(rank, non_blocking=True)
+                                if pitch_guidance
+                                else None
+                            ),
+                            (
+                                pitchf.cuda(rank, non_blocking=True)
+                                if pitch_guidance
+                                else None
+                            ),
                             spec.cuda(rank, non_blocking=True),
                             spec_lengths.cuda(rank, non_blocking=True),
                             wave.cuda(rank, non_blocking=True),
                             wave_lengths.cuda(rank, non_blocking=True),
-                            sid.cuda(rank, non_blocking=True)
-                        )
+                            sid.cuda(rank, non_blocking=True),
+                        ),
                     )
                 )
     else:
@@ -567,12 +623,24 @@ def train_and_evaluate(
     epoch_recorder = EpochRecorder()
     with tqdm(total=len(train_loader), leave=False) as pbar:
         for batch_idx, info in data_iterator:
-            phone, phone_lengths, pitch, pitchf, spec, spec_lengths, wave, wave_lengths, sid = info
+            (
+                phone,
+                phone_lengths,
+                pitch,
+                pitchf,
+                spec,
+                spec_lengths,
+                wave,
+                wave_lengths,
+                sid,
+            ) = info
             if device.type == "cuda" and not cache_data_in_gpu:
                 phone = phone.cuda(rank, non_blocking=True)
                 phone_lengths = phone_lengths.cuda(rank, non_blocking=True)
                 pitch = pitch.cuda(rank, non_blocking=True) if pitch_guidance else None
-                pitchf = pitchf.cuda(rank, non_blocking=True) if pitch_guidance else None
+                pitchf = (
+                    pitchf.cuda(rank, non_blocking=True) if pitch_guidance else None
+                )
                 sid = sid.cuda(rank, non_blocking=True)
                 spec = spec.cuda(rank, non_blocking=True)
                 spec_lengths = spec_lengths.cuda(rank, non_blocking=True)
@@ -592,16 +660,27 @@ def train_and_evaluate(
             # Forward pass
             use_amp = config.train.fp16_run and device.type == "cuda"
             with autocast(enabled=use_amp):
-                (y_hat, ids_slice, x_mask, z_mask, (z, z_p, m_p, logs_p, m_q, logs_q)) = net_g(phone, phone_lengths, pitch, pitchf, spec, spec_lengths, sid)
+                (
+                    y_hat,
+                    ids_slice,
+                    x_mask,
+                    z_mask,
+                    (z, z_p, m_p, logs_p, m_q, logs_q),
+                ) = net_g(phone, phone_lengths, pitch, pitchf, spec, spec_lengths, sid)
                 mel = spec_to_mel_torch(
                     spec,
                     config.data.filter_length,
                     config.data.n_mel_channels,
                     config.data.sample_rate,
                     config.data.mel_fmin,
-                    config.data.mel_fmax
+                    config.data.mel_fmax,
+                )
+                y_mel = commons.slice_segments(
+                    mel,
+                    ids_slice,
+                    config.train.segment_size // config.data.hop_length,
+                    dim=3,
                 )
-                y_mel = commons.slice_segments(mel, ids_slice, config.train.segment_size // config.data.hop_length, dim=3)
                 with autocast(enabled=False):
                     y_hat_mel = mel_spectrogram_torch(
                         y_hat.float().squeeze(1),
@@ -611,14 +690,21 @@ def train_and_evaluate(
                         config.data.hop_length,
                         config.data.win_length,
                         config.data.mel_fmin,
-                        config.data.mel_fmax
+                        config.data.mel_fmax,
                     )
                     if use_amp:
                         y_hat_mel = y_hat_mel.half()
-                wave = commons.slice_segments(wave, ids_slice * config.data.hop_length, config.train.segment_size, dim=3)
+                wave = commons.slice_segments(
+                    wave,
+                    ids_slice * config.data.hop_length,
+                    config.train.segment_size,
+                    dim=3,
+                )
                 y_d_hat_r, y_d_hat_g, _, _ = net_d(wave, y_hat.detach())
                 with autocast(enabled=False):
-                    loss_disc, losses_disc_r, losses_disc_g = discriminator_loss(y_d_hat_r, y_d_hat_g)
+                    loss_disc, losses_disc_r, losses_disc_g = discriminator_loss(
+                        y_d_hat_r, y_d_hat_g
+                    )
             # Discriminator backward and update
             optim_d.zero_grad()
             scaler.scale(loss_disc).backward()
@@ -631,7 +717,9 @@ def train_and_evaluate(
                 y_d_hat_r, y_d_hat_g, fmap_r, fmap_g = net_d(wave, y_hat)
                 with autocast(enabled=False):
                     loss_mel = F.l1_loss(y_mel, y_hat_mel) * config.train.c_mel
-                    loss_kl = (kl_loss(z_p, logs_q, m_p, logs_p, z_mask) * config.train.c_kl)
+                    loss_kl = (
+                        kl_loss(z_p, logs_q, m_p, logs_p, z_mask) * config.train.c_kl
+                    )
                     loss_fm = feature_loss(fmap_r, fmap_g)
                     loss_gen, losses_gen = generator_loss(y_d_hat_g)
                     loss_gen_all = loss_gen + loss_fm + loss_mel + loss_kl
@@ -642,7 +730,9 @@ def train_and_evaluate(
                 lowest_value["epoch"] = epoch
                 # print(f'Lowest generator loss updated: {lowest_value["value"]} at epoch {epoch}, step {global_step}')
                 if epoch > lowest_value["epoch"]:
-                    print("Alert: The lower generating loss has been exceeded by a lower loss in a subsequent epoch.")
+                    print(
+                        "Alert: The lower generating loss has been exceeded by a lower loss in a subsequent epoch."
+                    )
 
             optim_g.zero_grad()
             scaler.scale(loss_gen_all).backward()
@@ -666,13 +756,29 @@ def train_and_evaluate(
                     "grad_norm_d": grad_norm_d,
                     "grad_norm_g": grad_norm_g,
                 }
-                scalar_dict.update({"loss/g/fm": loss_fm, "loss/g/mel": loss_mel, "loss/g/kl": loss_kl})
-                scalar_dict.update({f"loss/g/{i}": v for i, v in enumerate(losses_gen)})
-                scalar_dict.update({f"loss/d_r/{i}": v for i, v in enumerate(losses_disc_r)})
-                scalar_dict.update({f"loss/d_g/{i}": v for i, v in enumerate(losses_disc_g)})
+                scalar_dict.update(
+                    {
+                        "loss/g/fm": loss_fm,
+                        "loss/g/mel": loss_mel,
+                        "loss/g/kl": loss_kl,
+                    }
+                )
+                scalar_dict.update(
+                    {f"loss/g/{i}": v for i, v in enumerate(losses_gen)}
+                )
+                scalar_dict.update(
+                    {f"loss/d_r/{i}": v for i, v in enumerate(losses_disc_r)}
+                )
+                scalar_dict.update(
+                    {f"loss/d_g/{i}": v for i, v in enumerate(losses_disc_g)}
+                )
                 image_dict = {
-                    "slice/mel_org": plot_spectrogram_to_numpy(y_mel[0].data.cpu().numpy()),
-                    "slice/mel_gen": plot_spectrogram_to_numpy(y_hat_mel[0].data.cpu().numpy()),
+                    "slice/mel_org": plot_spectrogram_to_numpy(
+                        y_mel[0].data.cpu().numpy()
+                    ),
+                    "slice/mel_gen": plot_spectrogram_to_numpy(
+                        y_hat_mel[0].data.cpu().numpy()
+                    ),
                     "all/mel": plot_spectrogram_to_numpy(mel[0].data.cpu().numpy()),
                 }
                 summarize(
@@ -689,15 +795,31 @@ def train_and_evaluate(
     model_add = []
     model_del = []
    done = False
-
+
     if rank == 0:
         # Save weights every N epochs
         if epoch % save_every_epoch == 0:
-            checkpoint_suffix = (f"{2333333 if save_only_latest else global_step}.pth")
-            save_checkpoint(net_g, optim_g, config.train.learning_rate, epoch, os.path.join(experiment_dir, "G_" + checkpoint_suffix))
-            save_checkpoint(net_d, optim_d, config.train.learning_rate, epoch, os.path.join(experiment_dir, "D_" + checkpoint_suffix))
+            checkpoint_suffix = f"{2333333 if save_only_latest else global_step}.pth"
+            save_checkpoint(
+                net_g,
+                optim_g,
+                config.train.learning_rate,
+                epoch,
+                os.path.join(experiment_dir, "G_" + checkpoint_suffix),
+            )
+            save_checkpoint(
+                net_d,
+                optim_d,
+                config.train.learning_rate,
+                epoch,
+                os.path.join(experiment_dir, "D_" + checkpoint_suffix),
+            )
             if custom_save_every_weights:
-                model_add.append(os.path.join(experiment_dir, f"{model_name}_{epoch}e_{global_step}s.pth"))
+                model_add.append(
+                    os.path.join(
+                        experiment_dir, f"{model_name}_{epoch}e_{global_step}s.pth"
+                    )
+                )
     overtrain_info = ""
     # Check overtraining
     if overtraining_detector and rank == 0 and epoch > 1:
@@ -705,9 +827,13 @@ def train_and_evaluate(
         current_loss_disc = float(loss_disc)
         loss_disc_history.append(current_loss_disc)
         # Update smoothed loss history with loss_disc
-        smoothed_value_disc = update_exponential_moving_average(smoothed_loss_disc_history, current_loss_disc)
+        smoothed_value_disc = update_exponential_moving_average(
+            smoothed_loss_disc_history, current_loss_disc
+        )
         # Check overtraining with smoothed loss_disc
-        is_overtraining_disc = check_overtraining(smoothed_loss_disc_history, overtraining_threshold * 2)
+        is_overtraining_disc = check_overtraining(
+            smoothed_loss_disc_history, overtraining_threshold * 2
+        )
         if is_overtraining_disc:
             consecutive_increases_disc += 1
         else:
@@ -716,9 +842,13 @@ def train_and_evaluate(
         current_loss_gen = float(lowest_value["value"])
         loss_gen_history.append(current_loss_gen)
         # Update the smoothed loss_gen history
-        smoothed_value_gen = update_exponential_moving_average(smoothed_loss_gen_history, current_loss_gen)
+        smoothed_value_gen = update_exponential_moving_average(
+            smoothed_loss_gen_history, current_loss_gen
+        )
         # Check for overtraining with the smoothed loss_gen
-        is_overtraining_gen = check_overtraining(smoothed_loss_gen_history, overtraining_threshold, 0.01)
+        is_overtraining_gen = check_overtraining(
+            smoothed_loss_gen_history, overtraining_threshold, 0.01
+        )
         if is_overtraining_gen:
             consecutive_increases_gen += 1
         else:
@@ -726,24 +856,50 @@ def train_and_evaluate(
         overtrain_info = f"Smoothed loss_g {smoothed_value_gen:.3f} and loss_d {smoothed_value_disc:.3f}"
         # Save the data in the JSON file if the epoch is divisible by save_every_epoch
         if epoch % save_every_epoch == 0:
-            save_to_json(training_file_path, loss_disc_history, smoothed_loss_disc_history, loss_gen_history, smoothed_loss_gen_history)
+            save_to_json(
+                training_file_path,
+                loss_disc_history,
+                smoothed_loss_disc_history,
+                loss_gen_history,
+                smoothed_loss_gen_history,
+            )
 
-        if is_overtraining_gen and consecutive_increases_gen == overtraining_threshold or is_overtraining_disc and consecutive_increases_disc == overtraining_threshold * 2:
-            print(f"Overtraining detected at epoch {epoch} with smoothed loss_g {smoothed_value_gen:.3f} and loss_d {smoothed_value_disc:.3f}")
+        if (
+            is_overtraining_gen
+            and consecutive_increases_gen == overtraining_threshold
+            or is_overtraining_disc
+            and consecutive_increases_disc == overtraining_threshold * 2
+        ):
+            print(
+                f"Overtraining detected at epoch {epoch} with smoothed loss_g {smoothed_value_gen:.3f} and loss_d {smoothed_value_disc:.3f}"
+            )
             done = True
         else:
-            print(f"New best epoch {epoch} with smoothed loss_g {smoothed_value_gen:.3f} and loss_d {smoothed_value_disc:.3f}")
-            old_model_files = glob.glob(os.path.join(experiment_dir, f"{model_name}_*e_*s_best_epoch.pth"))
+            print(
+                f"New best epoch {epoch} with smoothed loss_g {smoothed_value_gen:.3f} and loss_d {smoothed_value_disc:.3f}"
+            )
+            old_model_files = glob.glob(
+                os.path.join(experiment_dir, f"{model_name}_*e_*s_best_epoch.pth")
+            )
             for file in old_model_files:
                 model_del.append(file)
-            model_add.append(os.path.join(experiment_dir, f"{model_name}_{epoch}e_{global_step}s_best_epoch.pth"))
+            model_add.append(
+                os.path.join(
+                    experiment_dir,
+                    f"{model_name}_{epoch}e_{global_step}s_best_epoch.pth",
+                )
+            )
 
     # Check completion
     if epoch >= custom_total_epoch:
         lowest_value_rounded = float(lowest_value["value"])
         lowest_value_rounded = round(lowest_value_rounded, 3)
-        print(f"Training has been successfully completed with {epoch} epoch, {global_step} steps and {round(loss_gen_all.item(), 3)} loss gen.")
-        print(f"Lowest generator loss: {lowest_value_rounded} at epoch {lowest_value['epoch']}, step {lowest_value['step']}")
+        print(
+            f"Training has been successfully completed with {epoch} epoch, {global_step} steps and {round(loss_gen_all.item(), 3)} loss gen."
+        )
+        print(
+            f"Lowest generator loss: {lowest_value_rounded} at epoch {lowest_value['epoch']}, step {lowest_value['step']}"
+        )
 
         pid_file_path = os.path.join(experiment_dir, "config.json")
         with open(pid_file_path, "r") as pid_file:
@@ -752,11 +908,19 @@ def train_and_evaluate(
             pid_data = json.load(pid_file)
             pid_data.pop("process_pids", None)
             json.dump(pid_data, pid_file, indent=4)
         # Final model
-        model_add.append(os.path.join(experiment_dir, f"{model_name}_{epoch}e_{global_step}s.pth"))
+        model_add.append(
+            os.path.join(
+                experiment_dir, f"{model_name}_{epoch}e_{global_step}s.pth"
+            )
+        )
         done = True
-
+
     if model_add:
-        ckpt = net_g.module.state_dict() if hasattr(net_g, "module") else net_g.state_dict()
+        ckpt = (
+            net_g.module.state_dict()
+            if hasattr(net_g, "module")
+            else net_g.state_dict()
+        )
         for m in model_add:
             if not os.path.exists(m):
                 extract_model(
@@ -769,7 +933,7 @@ def train_and_evaluate(
                     step=global_step,
                     version=version,
                     hps=hps,
-                    overtrain_info=overtrain_info
+                    overtrain_info=overtrain_info,
                 )
     # Clean-up old best epochs
     for m in model_del:
@@ -781,18 +945,27 @@ def train_and_evaluate(
         record = f"{model_name} | epoch={epoch} | step={global_step} | {epoch_recorder.record()}"
         if epoch > 1:
-            record = record + f" | lowest_value={lowest_value_rounded} (epoch {lowest_value['epoch']} and step {lowest_value['step']})"
-
+            record = (
+                record
+                + f" | lowest_value={lowest_value_rounded} (epoch {lowest_value['epoch']} and step {lowest_value['step']})"
+            )
+
         if overtraining_detector:
             remaining_epochs_gen = overtraining_threshold - consecutive_increases_gen
-            remaining_epochs_disc = overtraining_threshold * 2 - consecutive_increases_disc
-            record = record + f" | Number of epochs remaining for overtraining: g/total: {remaining_epochs_gen} d/total: {remaining_epochs_disc} | smoothed_loss_gen={smoothed_value_gen:.3f} | smoothed_loss_disc={smoothed_value_disc:.3f}"
+            remaining_epochs_disc = (
+                overtraining_threshold * 2 - consecutive_increases_disc
+            )
+            record = (
+                record
+                + f" | Number of epochs remaining for overtraining: g/total: {remaining_epochs_gen} d/total: {remaining_epochs_disc} | smoothed_loss_gen={smoothed_value_gen:.3f} | smoothed_loss_disc={smoothed_value_disc:.3f}"
+            )
         print(record)
         last_loss_gen_all = loss_gen_all
 
     if done:
         os._exit(2333333)
 
+
 def check_overtraining(smoothed_loss_history, threshold, epsilon=0.004):
     """
     Checks for overtraining based on the smoothed loss history.
@@ -812,7 +985,10 @@ def check_overtraining(smoothed_loss_history, threshold, epsilon=0.004):
             return False
     return True
 
-def update_exponential_moving_average(smoothed_loss_history, new_value, smoothing=0.987):
+
+def update_exponential_moving_average(
+    smoothed_loss_history, new_value, smoothing=0.987
+):
     """
     Updates the exponential moving average with a new value.
@@ -822,19 +998,22 @@ def update_exponential_moving_average(smoothed_loss_history, new_value, smoothin
         smoothing (float): Smoothing factor.
     """
     if smoothed_loss_history:
-        smoothed_value = smoothing * smoothed_loss_history[-1] + (1 - smoothing) * new_value
+        smoothed_value = (
+            smoothing * smoothed_loss_history[-1] + (1 - smoothing) * new_value
+        )
     else:
         smoothed_value = new_value
     smoothed_loss_history.append(smoothed_value)
     return smoothed_value
 
+
 def save_to_json(
     file_path,
     loss_disc_history,
     smoothed_loss_disc_history,
     loss_gen_history,
     smoothed_loss_gen_history,
-    ):
+):
     """
     Save the training history to a JSON file.
     """
""" @@ -847,6 +1026,7 @@ def save_to_json( with open(file_path, "w") as f: json.dump(data, f) + if __name__ == "__main__": torch.multiprocessing.set_start_method("spawn") main()