diff --git a/core.py b/core.py index afcbaedc..e604f4d2 100644 --- a/core.py +++ b/core.py @@ -517,7 +517,6 @@ def run_train_script( index_algorithm: str = "Auto", cache_data_in_gpu: bool = False, custom_pretrained: bool = False, - use_cpu: bool = False, g_pretrained_path: str = None, d_pretrained_path: str = None, ): @@ -561,7 +560,6 @@ def run_train_script( overtraining_detector, overtraining_threshold, sync_graph, - use_cpu, ], ), ] @@ -1473,13 +1471,6 @@ def parse_arguments(): default="Auto", required=False, ) - train_parser.add_argument( - "--use_cpu", - type=lambda x: bool(strtobool(x)), - choices=[True, False], - help="Force the use of CPU for training.", - default=False, - ) # Parser for 'index' mode index_parser = subparsers.add_parser( @@ -1784,7 +1775,6 @@ def main(): sync_graph=args.sync_graph, index_algorithm=args.index_algorithm, cache_data_in_gpu=args.cache_data_in_gpu, - use_cpu=args.use_cpu, g_pretrained_path=args.g_pretrained_path, d_pretrained_path=args.d_pretrained_path, ) diff --git a/rvc/infer/pipeline.py b/rvc/infer/pipeline.py index 45b9c18b..8a988340 100644 --- a/rvc/infer/pipeline.py +++ b/rvc/infer/pipeline.py @@ -259,7 +259,9 @@ def get_f0_hybrid( for method in methods: f0 = None if method == "crepe": - f0 = self.get_f0_crepe(x, f0_min, f0_max, p_len, int(hop_length)) + f0 = self.get_f0_crepe_computation( + x, f0_min, f0_max, p_len, int(hop_length) + ) elif method == "rmvpe": self.model_rmvpe = RMVPE0Predictor( os.path.join("rvc", "models", "predictors", "rmvpe.pt"), @@ -412,82 +414,44 @@ def voice_conversion( version: Model version ("v1" or "v2"). protect: Protection level for preserving the original pitch. """ - feats = torch.from_numpy(audio0) - if self.is_half: - feats = feats.half() - else: - feats = feats.float() - if feats.dim() == 2: - feats = feats.mean(-1) - assert feats.dim() == 1, feats.dim() - feats = feats.view(1, -1) - padding_mask = torch.BoolTensor(feats.shape).to(self.device).fill_(False) - with torch.no_grad(): - feats = model(feats.to(self.device))["last_hidden_state"] - feats = ( - model.final_proj(feats[0]).unsqueeze(0) if version == "v1" else feats - ) - if protect < 0.5 and pitch != None and pitchf != None: - feats0 = feats.clone() - if ( - isinstance(index, type(None)) == False - and isinstance(big_npy, type(None)) == False - and index_rate != 0 - ): - npy = feats[0].cpu().numpy() - if self.is_half: - npy = npy.astype("float32") - - score, ix = index.search(npy, k=8) - weight = np.square(1 / score) - weight /= weight.sum(axis=1, keepdims=True) - npy = np.sum(big_npy[ix] * np.expand_dims(weight, axis=2), axis=1) - - if self.is_half: - npy = npy.astype("float16") - feats = ( - torch.from_numpy(npy).unsqueeze(0).to(self.device) * index_rate - + (1 - index_rate) * feats - ) - - feats = F.interpolate(feats.permute(0, 2, 1), scale_factor=2).permute(0, 2, 1) - if protect < 0.5 and pitch != None and pitchf != None: - feats0 = F.interpolate(feats0.permute(0, 2, 1), scale_factor=2).permute( - 0, 2, 1 - ) - p_len = audio0.shape[0] // self.window - if feats.shape[1] < p_len: - p_len = feats.shape[1] - if pitch != None and pitchf != None: - pitch = pitch[:, :p_len] - pitchf = pitchf[:, :p_len] - - if protect < 0.5 and pitch != None and pitchf != None: - pitchff = pitchf.clone() - pitchff[pitchf > 0] = 1 - pitchff[pitchf < 1] = protect - pitchff = pitchff.unsqueeze(-1) - feats = feats * pitchff + feats0 * (1 - pitchff) - feats = feats.to(feats0.dtype) - p_len = torch.tensor([p_len], device=self.device).long() - with torch.no_grad(): - if pitch != None 
and pitchf != None:
-                audio1 = (
-                    (net_g.infer(feats, p_len, pitch, pitchf, sid)[0][0, 0])
-                    .data.cpu()
-                    .float()
-                    .numpy()
-                )
+        pitch_guidance = pitch is not None and pitchf is not None
+        # prepare source audio
+        feats = torch.from_numpy(audio0).half() if self.is_half else torch.from_numpy(audio0).float()
+        feats = feats.mean(-1) if feats.dim() == 2 else feats
+        assert feats.dim() == 1, feats.dim()
+        feats = feats.view(1, -1).to(self.device)
+        # extract features without tracking gradients (inference only)
+        with torch.no_grad():
+            feats = model(feats)["last_hidden_state"]
+            feats = model.final_proj(feats[0]).unsqueeze(0) if version == "v1" else feats
+        # make a copy for pitch guidance and protection
+        feats0 = feats.clone() if pitch_guidance else None
+        if index:  # set by parent function, only true if index is available, loaded, and index rate > 0
+            feats = self._retrieve_speaker_embeddings(feats, index, big_npy, index_rate)
+        # feature upsampling
+        feats = F.interpolate(feats.permute(0, 2, 1), scale_factor=2).permute(0, 2, 1)
+        # adjust the length if the audio is short
+        p_len = min(audio0.shape[0] // self.window, feats.shape[1])
+        if pitch_guidance:
+            feats0 = F.interpolate(feats0.permute(0, 2, 1), scale_factor=2).permute(0, 2, 1)
+            pitch, pitchf = pitch[:, :p_len], pitchf[:, :p_len]
+            # pitch protection blending
+            if protect < 0.5:
+                pitchff = pitchf.clone()
+                pitchff[pitchf > 0] = 1
+                pitchff[pitchf < 1] = protect
+                feats = feats * pitchff.unsqueeze(-1) + feats0 * (1 - pitchff.unsqueeze(-1))
+                feats = feats.to(feats0.dtype)
         else:
-            audio1 = (
-                (net_g.infer(feats, p_len, sid)[0][0, 0]).data.cpu().float().numpy()
-            )
-        del feats, p_len, padding_mask
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
+            pitch, pitchf = None, None
+        p_len = torch.tensor([p_len], device=self.device).long()
+        with torch.no_grad():
+            audio1 = (net_g.infer(feats, p_len, pitch, pitchf, sid)[0][0, 0]).data.cpu().float().numpy()
+        # clean up
+        del feats, feats0, p_len
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
         return audio1
-
+
+    def _retrieve_speaker_embeddings(self, feats, index, big_npy, index_rate):
         npy = feats[0].cpu().numpy()
         npy = npy.astype("float32") if self.is_half else npy
@@ -496,12 +460,9 @@ def _retrieve_speaker_embeddings(self, feats, index, big_npy, index_rate):
         weight /= weight.sum(axis=1, keepdims=True)
         npy = np.sum(big_npy[ix] * np.expand_dims(weight, axis=2), axis=1)
         npy = npy.astype("float16") if self.is_half else npy
-        feats = (
-            torch.from_numpy(npy).unsqueeze(0).to(self.device) * index_rate
-            + (1 - index_rate) * feats
-        )
+        feats = torch.from_numpy(npy).unsqueeze(0).to(self.device) * index_rate + (1 - index_rate) * feats
         return feats
-
+
     def pipeline(
         self,
         model,
@@ -689,7 +650,9 @@ def pipeline(
         audio_max = np.abs(audio_opt).max() / 0.99
         if audio_max > 1:
             audio_opt /= audio_max
-        del pitch, pitchf, sid
+        if pitch_guidance:
+            del pitch, pitchf
+        del sid
         if torch.cuda.is_available():
             torch.cuda.empty_cache()
         return audio_opt
diff --git a/rvc/train/data_utils.py b/rvc/train/data_utils.py
index 2013558f..a4e7ef0f 100644
--- a/rvc/train/data_utils.py
+++ b/rvc/train/data_utils.py
@@ -6,7 +6,6 @@
 from mel_processing import spectrogram_torch
 from utils import load_filepaths_and_text, load_wav_to_torch

-
 class TextAudioLoaderMultiNSFsid(torch.utils.data.Dataset):
     """
     Dataset that loads text and audio pairs.

     Args:
         hparams: Hyperparameters.
     """
@@ -164,7 +163,6 @@ def __len__(self):
         """
         return len(self.audiopaths_and_text)

-
 class TextAudioCollateMultiNSFsid:
     """
     Collates text and audio data for training.
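A note on the pipeline.py refactor above: the inline FAISS lookup is now isolated in _retrieve_speaker_embeddings. The following standalone sketch mirrors that retrieval-and-blend step so the weighting is easier to follow; the index, dimensions, and random inputs are hypothetical stand-ins, not part of the patch.

import faiss
import numpy as np

dim = 768                                              # hypothetical v2 feature dim
big_npy = np.random.rand(1000, dim).astype("float32")  # stand-in training features
index = faiss.IndexFlatL2(dim)
index.add(big_npy)

feats = np.random.rand(200, dim).astype("float32")     # stand-in HuBERT frames
index_rate = 0.75

score, ix = index.search(feats, k=8)     # 8 nearest training frames per input frame
weight = np.square(1 / score)            # inverse-square distance weighting
weight /= weight.sum(axis=1, keepdims=True)
retrieved = np.sum(big_npy[ix] * np.expand_dims(weight, axis=2), axis=1)
feats = retrieved * index_rate + (1 - index_rate) * feats  # blend, as in the helper

Because the blend is linear in index_rate, a rate of 0 reduces to the raw features, which is why the parent function can skip the lookup entirely in that case (the "only true if ... index rate > 0" comment above).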
@@ -243,219 +241,6 @@ def __call__(self, batch): ) -class TextAudioLoader(torch.utils.data.Dataset): - """ - Dataset that loads text and audio pairs. - - Args: - hparams: Hyperparameters. - """ - - def __init__(self, hparams): - self.audiopaths_and_text = load_filepaths_and_text(hparams.training_files) - self.max_wav_value = hparams.max_wav_value - self.sample_rate = hparams.sample_rate - self.filter_length = hparams.filter_length - self.hop_length = hparams.hop_length - self.win_length = hparams.win_length - self.sample_rate = hparams.sample_rate - self.min_text_len = getattr(hparams, "min_text_len", 1) - self.max_text_len = getattr(hparams, "max_text_len", 5000) - self._filter() - - def _filter(self): - """ - Filters audio paths and text pairs based on text length. - """ - audiopaths_and_text_new = [] - lengths = [] - for entry in self.audiopaths_and_text: - if len(entry) >= 3: - audiopath, text, dv = entry[:3] - if self.min_text_len <= len(text) and len(text) <= self.max_text_len: - audiopaths_and_text_new.append([audiopath, text, dv]) - lengths.append(os.path.getsize(audiopath) // (3 * self.hop_length)) - - self.audiopaths_and_text = audiopaths_and_text_new - self.lengths = lengths - - def get_sid(self, sid): - """ - Converts speaker ID to a LongTensor. - - Args: - sid (str): Speaker ID. - """ - try: - sid = torch.LongTensor([int(sid)]) - except ValueError as error: - print(f"Error converting speaker ID '{sid}' to integer. Exception: {error}") - sid = torch.LongTensor([0]) - return sid - - def get_audio_text_pair(self, audiopath_and_text): - """ - Loads and processes audio and text data for a single pair. - - Args: - audiopath_and_text (list): List containing audio path, text, and speaker ID. - """ - file = audiopath_and_text[0] - phone = audiopath_and_text[1] - dv = audiopath_and_text[2] - - phone = self.get_labels(phone) - spec, wav = self.get_audio(file) - dv = self.get_sid(dv) - - len_phone = phone.size()[0] - len_spec = spec.size()[-1] - if len_phone != len_spec: - len_min = min(len_phone, len_spec) - len_wav = len_min * self.hop_length - spec = spec[:, :len_min] - wav = wav[:, :len_wav] - phone = phone[:len_min, :] - return (spec, wav, phone, dv) - - def get_labels(self, phone): - """ - Loads and processes phoneme labels. - - Args: - phone (str): Path to phoneme label file. - """ - phone = np.load(phone) - phone = np.repeat(phone, 2, axis=0) - n_num = min(phone.shape[0], 900) - phone = phone[:n_num, :] - phone = torch.FloatTensor(phone) - return phone - - def get_audio(self, filename): - """ - Loads and processes audio data. - - Args: - filename (str): Path to audio file. 
- """ - audio, sample_rate = load_wav_to_torch(filename) - if sample_rate != self.sample_rate: - raise ValueError( - f"{sample_rate} SR doesn't match target {self.sample_rate} SR" - ) - audio_norm = audio - audio_norm = audio_norm.unsqueeze(0) - spec_filename = filename.replace(".wav", ".spec.pt") - if os.path.exists(spec_filename): - try: - spec = torch.load(spec_filename) - except Exception as error: - print(f"An error occurred getting spec from {spec_filename}: {error}") - spec = spectrogram_torch( - audio_norm, - self.filter_length, - self.hop_length, - self.win_length, - center=False, - ) - spec = torch.squeeze(spec, 0) - torch.save(spec, spec_filename, _use_new_zipfile_serialization=False) - else: - spec = spectrogram_torch( - audio_norm, - self.filter_length, - self.hop_length, - self.win_length, - center=False, - ) - spec = torch.squeeze(spec, 0) - torch.save(spec, spec_filename, _use_new_zipfile_serialization=False) - return spec, audio_norm - - def __getitem__(self, index): - """ - Returns a single audio-text pair. - - Args: - index (int): Index of the data sample. - """ - return self.get_audio_text_pair(self.audiopaths_and_text[index]) - - def __len__(self): - """ - Returns the length of the dataset. - """ - return len(self.audiopaths_and_text) - - -class TextAudioCollate: - """ - Collates text and audio data for training. - - Args: - return_ids (bool, optional): Whether to return sample IDs. Defaults to False. - """ - - def __init__(self, return_ids=False): - self.return_ids = return_ids - - def __call__(self, batch): - """ - Collates a batch of data samples. - - Args: - batch (list): List of data samples. - """ - _, ids_sorted_decreasing = torch.sort( - torch.LongTensor([x[0].size(1) for x in batch]), dim=0, descending=True - ) - - max_spec_len = max([x[0].size(1) for x in batch]) - max_wave_len = max([x[1].size(1) for x in batch]) - spec_lengths = torch.LongTensor(len(batch)) - wave_lengths = torch.LongTensor(len(batch)) - spec_padded = torch.FloatTensor(len(batch), batch[0][0].size(0), max_spec_len) - wave_padded = torch.FloatTensor(len(batch), 1, max_wave_len) - spec_padded.zero_() - wave_padded.zero_() - - max_phone_len = max([x[2].size(0) for x in batch]) - phone_lengths = torch.LongTensor(len(batch)) - phone_padded = torch.FloatTensor( - len(batch), max_phone_len, batch[0][2].shape[1] - ) - phone_padded.zero_() - sid = torch.LongTensor(len(batch)) - - for i in range(len(ids_sorted_decreasing)): - row = batch[ids_sorted_decreasing[i]] - - spec = row[0] - spec_padded[i, :, : spec.size(1)] = spec - spec_lengths[i] = spec.size(1) - - wave = row[1] - wave_padded[i, :, : wave.size(1)] = wave - wave_lengths[i] = wave.size(1) - - phone = row[2] - phone_padded[i, : phone.size(0), :] = phone - phone_lengths[i] = phone.size(0) - - sid[i] = row[3] - - return ( - phone_padded, - phone_lengths, - spec_padded, - spec_lengths, - wave_padded, - wave_lengths, - sid, - ) - - class DistributedBucketSampler(torch.utils.data.distributed.DistributedSampler): """ Distributed sampler that groups data into buckets based on length. 
diff --git a/rvc/train/train.py b/rvc/train/train.py index 014abbb5..4054cad2 100644 --- a/rvc/train/train.py +++ b/rvc/train/train.py @@ -36,9 +36,7 @@ from data_utils import ( DistributedBucketSampler, - TextAudioCollate, TextAudioCollateMultiNSFsid, - TextAudioLoader, TextAudioLoaderMultiNSFsid, ) @@ -74,7 +72,6 @@ overtraining_detector = strtobool(sys.argv[14]) overtraining_threshold = int(sys.argv[15]) sync_graph = strtobool(sys.argv[16]) -use_cpu = strtobool(sys.argv[17]) current_dir = os.getcwd() experiment_dir = os.path.join(current_dir, "logs", model_name) @@ -86,10 +83,6 @@ config = HParams(**config) config.data.training_files = os.path.join(experiment_dir, "filelist.txt") -if not use_cpu: - os.environ["CUDA_VISIBLE_DEVICES"] = gpus.replace("-", ",") -n_gpus = len(gpus.split("-")) if not use_cpu else 1 - torch.backends.cudnn.deterministic = False torch.backends.cudnn.benchmark = False @@ -102,7 +95,6 @@ smoothed_loss_disc_history = [] lowest_value = {"step": 0, "value": float("inf"), "epoch": 0} training_file_path = os.path.join(experiment_dir, "training_data.json") -overtrain_info = None import logging @@ -156,37 +148,25 @@ def main(): os.environ["MASTER_ADDR"] = "localhost" os.environ["MASTER_PORT"] = str(randint(20000, 55555)) # Check sample rate - first_wav_file = next( - ( - filename - for filename in os.listdir(os.path.join(experiment_dir, "sliced_audios")) - if filename.endswith(".wav") - ), - None, - ) - if first_wav_file: - audio = os.path.join(experiment_dir, "sliced_audios", first_wav_file) - _, sr = load_wav_to_torch(audio) + wavs = glob.glob(os.path.join(os.path.join(experiment_dir, "sliced_audios"), "*.wav")) + if wavs: + _, sr = load_wav_to_torch(wavs[0]) if sr != sample_rate: - try: - raise ValueError( - f"Error: Pretrained model sample rate ({sample_rate} Hz) does not match dataset audio sample rate ({sr} Hz)." - ) - except ValueError as e: - print( - f"Error: Pretrained model sample rate ({sample_rate} Hz) does not match dataset audio sample rate ({sr} Hz)." - ) - sys.exit(1) + print(f"Error: Pretrained model sample rate ({sample_rate} Hz) does not match dataset audio sample rate ({sr} Hz).") + os._exit(1) else: print("No wav file found.") - use_gpu = torch.cuda.is_available() and not use_cpu - device = torch.device("cuda" if use_gpu else "cpu") - - if use_gpu: + if torch.cuda.is_available(): + device = torch.device("cuda") n_gpus = torch.cuda.device_count() + elif torch.backends.mps.is_available(): + device = torch.device("mps") + n_gpus = 1 else: + device = torch.device("cpu") n_gpus = 1 + print("Training with CPU, this will take a long time.") def start(): """ @@ -259,24 +239,11 @@ def continue_overtrain_detector(training_file_path): smoothed_loss_gen_history, ) = load_from_json(training_file_path) - if use_cpu: - n_gpus = 1 - print("Training with CPU, this will take a long time.") - else: - n_gpus = torch.cuda.device_count() - - if torch.cuda.is_available() == False and torch.backends.mps.is_available() == True: - n_gpus = 1 - if n_gpus < 1 and not use_cpu: - print("GPU not detected, reverting to CPU (not recommended)") - n_gpus = 1 - - if sync_graph == True: - print( - "Sync graph is now activated! With sync graph enabled, the model undergoes a single epoch of training. Once the graphs are synchronized, training proceeds for the previously specified number of epochs." - ) + if sync_graph: + print("Sync graph is now activated! With sync graph enabled, the model undergoes a single epoch of training. 
Once the graphs are synchronized, training proceeds for the previously specified number of epochs.")
         custom_total_epoch = 1
         custom_save_every_weights = True
+        start()

         # Synchronize graphs by modifying config files
@@ -321,19 +288,14 @@ def edit_config(config_file):
         edit_config(rvc_config_file)

         # Clean up unnecessary files
-        for root, dirs, files in os.walk(
-            os.path.join(now_dir, "logs", model_name), topdown=False
-        ):
+        for root, dirs, files in os.walk(os.path.join(now_dir, "logs", model_name), topdown=False):
             for name in files:
                 file_path = os.path.join(root, name)
                 file_name, file_extension = os.path.splitext(name)
-                if file_extension == ".0":
-                    os.remove(file_path)
-                elif ("D" in name or "G" in name) and file_extension == ".pth":
-                    os.remove(file_path)
-                elif (
-                    "added" in name or "trained" in name
-                ) and file_extension == ".index":
+                if (file_extension == ".0" or
+                    (file_name.startswith("D_") and file_extension == ".pth") or
+                    (file_name.startswith("G_") and file_extension == ".pth") or
+                    (file_name.startswith("added") and file_extension == ".index")):
                     os.remove(file_path)
             for name in dirs:
                 if name == "eval":
@@ -355,7 +317,6 @@ def edit_config(config_file):
     continue_overtrain_detector(training_file_path)
     start()

-
 def run(
     rank,
     n_gpus,
@@ -403,23 +364,16 @@ def run(

     if torch.cuda.is_available():
         torch.cuda.set_device(rank)
-
-    # Zluda
-    if torch.cuda.is_available() and torch.cuda.get_device_name().endswith("[ZLUDA]"):
-        print("Disabling CUDNN for traning with Zluda")
-        torch.backends.cudnn.enabled = False
-        torch.backends.cuda.enable_flash_sdp(False)
-        torch.backends.cuda.enable_math_sdp(True)
-        torch.backends.cuda.enable_mem_efficient_sdp(False)
+        if torch.cuda.get_device_name().endswith("[ZLUDA]"):
+            print("Disabling CUDNN for training with Zluda")
+            torch.backends.cudnn.enabled = False
+            torch.backends.cuda.enable_flash_sdp(False)
+            torch.backends.cuda.enable_math_sdp(True)
+            torch.backends.cuda.enable_mem_efficient_sdp(False)

     # Create datasets and dataloaders
-    if pitch_guidance == True:
-        train_dataset = TextAudioLoaderMultiNSFsid(config.data)
-    elif pitch_guidance == False:
-        train_dataset = TextAudioLoader(config.data)
-    else:
-        raise ValueError(f"Unexpected value for pitch_guidance: {pitch_guidance}")
-
+    train_dataset = TextAudioLoaderMultiNSFsid(config.data)
+    collate_fn = TextAudioCollateMultiNSFsid()
     train_sampler = DistributedBucketSampler(
         train_dataset,
         batch_size * n_gpus,
@@ -429,11 +383,6 @@ def run(
         shuffle=True,
     )

-    if pitch_guidance == True:
-        collate_fn = TextAudioCollateMultiNSFsid()
-    elif pitch_guidance == False:
-        collate_fn = TextAudioCollate()
-
     train_loader = DataLoader(
         train_dataset,
         num_workers=4,
@@ -450,49 +399,29 @@ def run(
         config.data.filter_length // 2 + 1,
         config.train.segment_size // config.data.hop_length,
         **config.model,
-        use_f0=pitch_guidance == True,
+        use_f0=pitch_guidance,
         is_half=config.train.fp16_run and device.type == "cuda",
         sr=sample_rate,
-    )
-
-    net_g = net_g.to(device)
+    ).to(device)

     if version == "v1":
-        net_d = MultiPeriodDiscriminator(config.model.use_spectral_norm)
+        net_d = MultiPeriodDiscriminator(config.model.use_spectral_norm).to(device)
     else:
-        net_d = MultiPeriodDiscriminatorV2(config.model.use_spectral_norm)
+        net_d = MultiPeriodDiscriminatorV2(config.model.use_spectral_norm).to(device)

-    net_d = net_d.to(device)
-
-    optim_g = torch.optim.AdamW(
-        net_g.parameters(),
-        config.train.learning_rate,
-        betas=config.train.betas,
-        eps=config.train.eps,
-    )
-    optim_d = torch.optim.AdamW(
-        net_d.parameters(),
-
config.train.learning_rate, - betas=config.train.betas, - eps=config.train.eps, - ) + optim_g = torch.optim.AdamW(net_g.parameters(), config.train.learning_rate, betas=config.train.betas, eps=config.train.eps) + optim_d = torch.optim.AdamW(net_d.parameters(), config.train.learning_rate, betas=config.train.betas, eps=config.train.eps) # Wrap models with DDP for multi-gpu processing if n_gpus > 1 and device.type == "cuda": net_g = DDP(net_g, device_ids=[rank]) net_d = DDP(net_d, device_ids=[rank]) - # else: - # net_g = DDP(net_g) - # net_d = DDP(net_d) + # Load checkpoint if available try: print("Starting training...") - _, _, _, epoch_str = load_checkpoint( - latest_checkpoint_path(experiment_dir, "D_*.pth"), net_d, optim_d - ) - _, _, _, epoch_str = load_checkpoint( - latest_checkpoint_path(experiment_dir, "G_*.pth"), net_g, optim_g - ) + _, _, _, epoch_str = load_checkpoint(latest_checkpoint_path(experiment_dir, "D_*.pth"), net_d, optim_d) + _, _, _, epoch_str = load_checkpoint(latest_checkpoint_path(experiment_dir, "G_*.pth"), net_g, optim_g) epoch_str += 1 global_step = (epoch_str - 1) * len(train_loader) @@ -504,40 +433,26 @@ def run( verify_checkpoint_shapes(pretrainG, net_g) print(f"Loaded pretrained (G) '{pretrainG}'") if hasattr(net_g, "module"): - net_g.module.load_state_dict( - torch.load(pretrainG, map_location="cpu")["model"] - ) - + net_g.module.load_state_dict(torch.load(pretrainG, map_location="cpu")["model"]) else: - net_g.load_state_dict( - torch.load(pretrainG, map_location="cpu")["model"] - ) + net_g.load_state_dict(torch.load(pretrainG, map_location="cpu")["model"]) if pretrainD != "": if rank == 0: print(f"Loaded pretrained (D) '{pretrainD}'") if hasattr(net_d, "module"): - net_d.module.load_state_dict( - torch.load(pretrainD, map_location="cpu")["model"] - ) - + net_d.module.load_state_dict(torch.load(pretrainD, map_location="cpu")["model"]) else: - net_d.load_state_dict( - torch.load(pretrainD, map_location="cpu")["model"] - ) + net_d.load_state_dict(torch.load(pretrainD, map_location="cpu")["model"]) # Initialize schedulers and scaler - scheduler_g = torch.optim.lr_scheduler.ExponentialLR( - optim_g, gamma=config.train.lr_decay, last_epoch=epoch_str - 2 - ) - scheduler_d = torch.optim.lr_scheduler.ExponentialLR( - optim_d, gamma=config.train.lr_decay, last_epoch=epoch_str - 2 - ) + scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=config.train.lr_decay, last_epoch=epoch_str - 2) + scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=config.train.lr_decay, last_epoch=epoch_str - 2) - optim_d.step() optim_g.step() + optim_d.step() - scaler = GradScaler(enabled=config.train.fp16_run and not use_cpu) + scaler = GradScaler(enabled=config.train.fp16_run and device.type == "cuda") cache = [] for epoch in range(epoch_str, total_epoch + 1): @@ -554,7 +469,7 @@ def run( cache, custom_save_every_weights, custom_total_epoch, - use_cpu, + device, ) else: train_and_evaluate( @@ -569,12 +484,11 @@ def run( cache, custom_save_every_weights, custom_total_epoch, - use_cpu, + device, ) scheduler_g.step() scheduler_d.step() - def train_and_evaluate( rank, epoch, @@ -587,7 +501,7 @@ def train_and_evaluate( cache, custom_save_every_weights, custom_total_epoch, - use_cpu, + device, ): """ Trains and evaluates the model for one epoch. 
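The hunks that follow replace every use_cpu check with gating on device.type, so fp16 paths are active only on CUDA and the same loop runs unchanged on MPS or CPU. A minimal sketch of the pattern (toy model and loss; assumes the torch.cuda.amp API already used in this file):

import torch
from torch.cuda.amp import GradScaler, autocast

if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

fp16_run = True                               # stand-in for config.train.fp16_run
use_amp = fp16_run and device.type == "cuda"  # AMP only on CUDA

model = torch.nn.Linear(4, 1).to(device)      # toy model
optim = torch.optim.AdamW(model.parameters(), 1e-4)
scaler = GradScaler(enabled=use_amp)          # pass-through when disabled

x = torch.randn(8, 4, device=device)
with autocast(enabled=use_amp):               # no-op context off CUDA
    loss = model(x).pow(2).mean()
optim.zero_grad()
scaler.scale(loss).backward()                 # plain backward() when disabled
scaler.step(optim)
scaler.update()

A disabled GradScaler and autocast are explicit no-ops in PyTorch, which is what lets the refactor delete the separate CPU branches instead of duplicating the training step.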
@@ -624,75 +538,27 @@ def train_and_evaluate( net_d.train() # Data caching - if cache_data_in_gpu and not use_cpu: + if device.type == "cuda" and cache_data_in_gpu: data_iterator = cache if cache == []: for batch_idx, info in enumerate(train_loader): - if pitch_guidance == True: + phone, phone_lengths, pitch, pitchf, spec, spec_lengths, wave, wave_lengths, sid = info + cache.append( ( - phone, - phone_lengths, - pitch, - pitchf, - spec, - spec_lengths, - wave, - wave_lengths, - sid, - ) = info - elif pitch_guidance == False: - ( - phone, - phone_lengths, - spec, - spec_lengths, - wave, - wave_lengths, - sid, - ) = info - if torch.cuda.is_available(): - phone = phone.cuda(rank, non_blocking=True) - phone_lengths = phone_lengths.cuda(rank, non_blocking=True) - if pitch_guidance == True: - pitch = pitch.cuda(rank, non_blocking=True) - pitchf = pitchf.cuda(rank, non_blocking=True) - sid = sid.cuda(rank, non_blocking=True) - spec = spec.cuda(rank, non_blocking=True) - spec_lengths = spec_lengths.cuda(rank, non_blocking=True) - wave = wave.cuda(rank, non_blocking=True) - wave_lengths = wave_lengths.cuda(rank, non_blocking=True) - if pitch_guidance == True: - cache.append( + batch_idx, ( - batch_idx, - ( - phone, - phone_lengths, - pitch, - pitchf, - spec, - spec_lengths, - wave, - wave_lengths, - sid, - ), - ) - ) - elif pitch_guidance == False: - cache.append( - ( - batch_idx, - ( - phone, - phone_lengths, - spec, - spec_lengths, - wave, - wave_lengths, - sid, - ), + phone.cuda(rank, non_blocking=True), + phone_lengths.cuda(rank, non_blocking=True), + pitch.cuda(rank, non_blocking=True) if pitch_guidance else None, + pitchf.cuda(rank, non_blocking=True) if pitch_guidance else None, + spec.cuda(rank, non_blocking=True), + spec_lengths.cuda(rank, non_blocking=True), + wave.cuda(rank, non_blocking=True), + wave_lengths.cuda(rank, non_blocking=True), + sid.cuda(rank, non_blocking=True) ) ) + ) else: shuffle(cache) else: @@ -701,81 +567,41 @@ def train_and_evaluate( epoch_recorder = EpochRecorder() with tqdm(total=len(train_loader), leave=False) as pbar: for batch_idx, info in data_iterator: - if pitch_guidance == True: - ( - phone, - phone_lengths, - pitch, - pitchf, - spec, - spec_lengths, - wave, - wave_lengths, - sid, - ) = info - elif pitch_guidance == False: - phone, phone_lengths, spec, spec_lengths, wave, wave_lengths, sid = info - if ( - (cache_data_in_gpu == False) - and not use_cpu - and torch.cuda.is_available() - ): + phone, phone_lengths, pitch, pitchf, spec, spec_lengths, wave, wave_lengths, sid = info + if device.type == "cuda" and not cache_data_in_gpu: phone = phone.cuda(rank, non_blocking=True) phone_lengths = phone_lengths.cuda(rank, non_blocking=True) - if pitch_guidance == True: - pitch = pitch.cuda(rank, non_blocking=True) - pitchf = pitchf.cuda(rank, non_blocking=True) + pitch = pitch.cuda(rank, non_blocking=True) if pitch_guidance else None + pitchf = pitchf.cuda(rank, non_blocking=True) if pitch_guidance else None sid = sid.cuda(rank, non_blocking=True) spec = spec.cuda(rank, non_blocking=True) spec_lengths = spec_lengths.cuda(rank, non_blocking=True) wave = wave.cuda(rank, non_blocking=True) wave_lengths = wave_lengths.cuda(rank, non_blocking=True) - elif use_cpu: - phone = phone.cpu() - phone_lengths = phone_lengths.cpu() - if pitch_guidance == True: - pitch = pitch.cpu() - pitchf = pitchf.cpu() - sid = sid.cpu() - spec = spec.cpu() - spec_lengths = spec_lengths.cpu() - wave = wave.cpu() - wave_lengths = wave_lengths.cpu() + else: + phone = phone.to(device) + 
phone_lengths = phone_lengths.to(device) + pitch = pitch.to(device) if pitch_guidance else None + pitchf = pitchf.to(device) if pitch_guidance else None + sid = sid.to(device) + spec = spec.to(device) + spec_lengths = spec_lengths.to(device) + wave = wave.to(device) + wave_lengths = wave_lengths.to(device) # Forward pass - with autocast(enabled=config.train.fp16_run and not use_cpu): - if pitch_guidance == True: - ( - y_hat, - ids_slice, - x_mask, - z_mask, - (z, z_p, m_p, logs_p, m_q, logs_q), - ) = net_g( - phone, phone_lengths, pitch, pitchf, spec, spec_lengths, sid - ) - elif pitch_guidance == False: - ( - y_hat, - ids_slice, - x_mask, - z_mask, - (z, z_p, m_p, logs_p, m_q, logs_q), - ) = net_g(phone, phone_lengths, spec, spec_lengths, sid) + use_amp = config.train.fp16_run and device.type == "cuda" + with autocast(enabled=use_amp): + (y_hat, ids_slice, x_mask, z_mask, (z, z_p, m_p, logs_p, m_q, logs_q)) = net_g(phone, phone_lengths, pitch, pitchf, spec, spec_lengths, sid) mel = spec_to_mel_torch( spec, config.data.filter_length, config.data.n_mel_channels, config.data.sample_rate, config.data.mel_fmin, - config.data.mel_fmax, - ) - y_mel = commons.slice_segments( - mel, - ids_slice, - config.train.segment_size // config.data.hop_length, - dim=3, + config.data.mel_fmax ) + y_mel = commons.slice_segments(mel, ids_slice, config.train.segment_size // config.data.hop_length, dim=3) with autocast(enabled=False): y_hat_mel = mel_spectrogram_torch( y_hat.float().squeeze(1), @@ -785,23 +611,14 @@ def train_and_evaluate( config.data.hop_length, config.data.win_length, config.data.mel_fmin, - config.data.mel_fmax, + config.data.mel_fmax ) - if config.train.fp16_run == True and not use_cpu: + if use_amp: y_hat_mel = y_hat_mel.half() - wave = commons.slice_segments( - wave, - ids_slice * config.data.hop_length, - config.train.segment_size, - dim=3, - ) - + wave = commons.slice_segments(wave, ids_slice * config.data.hop_length, config.train.segment_size, dim=3) y_d_hat_r, y_d_hat_g, _, _ = net_d(wave, y_hat.detach()) with autocast(enabled=False): - loss_disc, losses_disc_r, losses_disc_g = discriminator_loss( - y_d_hat_r, y_d_hat_g - ) - + loss_disc, losses_disc_r, losses_disc_g = discriminator_loss(y_d_hat_r, y_d_hat_g) # Discriminator backward and update optim_d.zero_grad() scaler.scale(loss_disc).backward() @@ -810,13 +627,11 @@ def train_and_evaluate( scaler.step(optim_d) # Generator backward and update - with autocast(enabled=config.train.fp16_run and not use_cpu): + with autocast(enabled=use_amp): y_d_hat_r, y_d_hat_g, fmap_r, fmap_g = net_d(wave, y_hat) with autocast(enabled=False): loss_mel = F.l1_loss(y_mel, y_hat_mel) * config.train.c_mel - loss_kl = ( - kl_loss(z_p, logs_q, m_p, logs_p, z_mask) * config.train.c_kl - ) + loss_kl = (kl_loss(z_p, logs_q, m_p, logs_p, z_mask) * config.train.c_kl) loss_fm = feature_loss(fmap_r, fmap_g) loss_gen, losses_gen = generator_loss(y_d_hat_g) loss_gen_all = loss_gen + loss_fm + loss_mel + loss_kl @@ -827,9 +642,7 @@ def train_and_evaluate( lowest_value["epoch"] = epoch # print(f'Lowest generator loss updated: {lowest_value["value"]} at epoch {epoch}, step {global_step}') if epoch > lowest_value["epoch"]: - print( - "Alert: The lower generating loss has been exceeded by a lower loss in a subsequent epoch." 
- ) + print("Alert: The lower generating loss has been exceeded by a lower loss in a subsequent epoch.") optim_g.zero_grad() scaler.scale(loss_gen_all).backward() @@ -842,12 +655,10 @@ def train_and_evaluate( if rank == 0: if global_step % config.train.log_interval == 0: lr = optim_g.param_groups[0]["lr"] - if loss_mel > 75: loss_mel = 75 if loss_kl > 9: loss_kl = 9 - scalar_dict = { "loss/g/total": loss_gen_all, "loss/d/total": loss_disc, @@ -855,29 +666,13 @@ def train_and_evaluate( "grad_norm_d": grad_norm_d, "grad_norm_g": grad_norm_g, } - scalar_dict.update( - { - "loss/g/fm": loss_fm, - "loss/g/mel": loss_mel, - "loss/g/kl": loss_kl, - } - ) - scalar_dict.update( - {f"loss/g/{i}": v for i, v in enumerate(losses_gen)} - ) - scalar_dict.update( - {f"loss/d_r/{i}": v for i, v in enumerate(losses_disc_r)} - ) - scalar_dict.update( - {f"loss/d_g/{i}": v for i, v in enumerate(losses_disc_g)} - ) + scalar_dict.update({"loss/g/fm": loss_fm, "loss/g/mel": loss_mel, "loss/g/kl": loss_kl}) + scalar_dict.update({f"loss/g/{i}": v for i, v in enumerate(losses_gen)}) + scalar_dict.update({f"loss/d_r/{i}": v for i, v in enumerate(losses_disc_r)}) + scalar_dict.update({f"loss/d_g/{i}": v for i, v in enumerate(losses_disc_g)}) image_dict = { - "slice/mel_org": plot_spectrogram_to_numpy( - y_mel[0].data.cpu().numpy() - ), - "slice/mel_gen": plot_spectrogram_to_numpy( - y_hat_mel[0].data.cpu().numpy() - ), + "slice/mel_org": plot_spectrogram_to_numpy(y_mel[0].data.cpu().numpy()), + "slice/mel_gen": plot_spectrogram_to_numpy(y_hat_mel[0].data.cpu().numpy()), "all/mel": plot_spectrogram_to_numpy(mel[0].data.cpu().numpy()), } summarize( @@ -891,265 +686,166 @@ def train_and_evaluate( pbar.update(1) # Save checkpoint - if epoch % save_every_epoch == False and rank == 0: - checkpoint_suffix = ( - f"{global_step if save_only_latest == False else 2333333}.pth" - ) - save_checkpoint( - net_g, - optim_g, - config.train.learning_rate, - epoch, - os.path.join(experiment_dir, "G_" + checkpoint_suffix), - ) - save_checkpoint( - net_d, - optim_d, - config.train.learning_rate, - epoch, - os.path.join(experiment_dir, "D_" + checkpoint_suffix), - ) - if rank == 0 and custom_save_every_weights == True: - if hasattr(net_g, "module"): - ckpt = net_g.module.state_dict() + model_add = [] + model_del = [] + done = False + + if rank == 0: + # Save weights every N epochs + if epoch % save_every_epoch == 0: + checkpoint_suffix = (f"{2333333 if save_only_latest else global_step}.pth") + save_checkpoint(net_g, optim_g, config.train.learning_rate, epoch, os.path.join(experiment_dir, "G_" + checkpoint_suffix)) + save_checkpoint(net_d, optim_d, config.train.learning_rate, epoch, os.path.join(experiment_dir, "D_" + checkpoint_suffix)) + if custom_save_every_weights: + model_add.append(os.path.join(experiment_dir, f"{model_name}_{epoch}e_{global_step}s.pth")) + overtrain_info = "" + # Check overtraining + if overtraining_detector and rank == 0 and epoch > 1: + # Add the current loss to the history + current_loss_disc = float(loss_disc) + loss_disc_history.append(current_loss_disc) + # Update smoothed loss history with loss_disc + smoothed_value_disc = update_exponential_moving_average(smoothed_loss_disc_history, current_loss_disc) + # Check overtraining with smoothed loss_disc + is_overtraining_disc = check_overtraining(smoothed_loss_disc_history, overtraining_threshold * 2) + if is_overtraining_disc: + consecutive_increases_disc += 1 else: - ckpt = net_g.state_dict() - if overtraining_detector and epoch > 1: - overtrain_info = 
f"Smoothed loss_g {smoothed_value_gen:.3f} and loss_d {smoothed_value_disc:.3f}" + consecutive_increases_disc = 0 + # Add the current loss_gen to the history + current_loss_gen = float(lowest_value["value"]) + loss_gen_history.append(current_loss_gen) + # Update the smoothed loss_gen history + smoothed_value_gen = update_exponential_moving_average(smoothed_loss_gen_history, current_loss_gen) + # Check for overtraining with the smoothed loss_gen + is_overtraining_gen = check_overtraining(smoothed_loss_gen_history, overtraining_threshold, 0.01) + if is_overtraining_gen: + consecutive_increases_gen += 1 else: - overtrain_info = "" - extract_model( - ckpt=ckpt, - sr=sample_rate, - pitch_guidance=pitch_guidance == True, - name=model_name, - model_dir=os.path.join( - experiment_dir, - f"{model_name}_{epoch}e_{global_step}s.pth", - ), - epoch=epoch, - step=global_step, - version=version, - hps=hps, - overtrain_info=overtrain_info, - ) - - def check_overtraining(smoothed_loss_history, threshold, epsilon=0.004): - """ - Checks for overtraining based on the smoothed loss history. - - Args: - smoothed_loss_history (list): List of smoothed losses for each epoch. - threshold (int): Number of consecutive epochs with insignificant changes or increases to consider overtraining. - epsilon (float): The maximum change considered insignificant. - """ - if len(smoothed_loss_history) < threshold + 1: - return False - - for i in range(-threshold, -1): - if smoothed_loss_history[i + 1] > smoothed_loss_history[i]: - return True - if abs(smoothed_loss_history[i + 1] - smoothed_loss_history[i]) >= epsilon: - return False - - return True - - def update_exponential_moving_average( - smoothed_loss_history, new_value, smoothing=0.987 - ): - """ - Updates the exponential moving average with a new value. - - Args: - smoothed_loss_history (list): List of smoothed values. - new_value (float): New value to be added. - smoothing (float): Smoothing factor. - """ - if not smoothed_loss_history: - smoothed_value = new_value - else: - smoothed_value = ( - smoothing * smoothed_loss_history[-1] + (1 - smoothing) * new_value - ) - smoothed_loss_history.append(smoothed_value) - return smoothed_value - - def save_to_json( - file_path, - loss_disc_history, - smoothed_loss_disc_history, - loss_gen_history, - smoothed_loss_gen_history, - ): - """ - Save the training history to a JSON file. 
- """ - data = { - "loss_disc_history": loss_disc_history, - "smoothed_loss_disc_history": smoothed_loss_disc_history, - "loss_gen_history": loss_gen_history, - "smoothed_loss_gen_history": smoothed_loss_gen_history, - } - with open(file_path, "w") as f: - json.dump(data, f) - - if overtraining_detector and rank == 0 and epoch > 1: - # Add the current loss to the history - current_loss_disc = float(loss_disc) - loss_disc_history.append(current_loss_disc) - - # Update smoothed loss history with loss_disc - smoothed_value_disc = update_exponential_moving_average( - smoothed_loss_disc_history, current_loss_disc - ) - - # Check overtraining with smoothed loss_disc - is_overtraining_disc = check_overtraining( - smoothed_loss_disc_history, overtraining_threshold * 2 - ) - if is_overtraining_disc: - consecutive_increases_disc += 1 - else: - consecutive_increases_disc = 0 - # Add the current loss_gen to the history - current_loss_gen = float(lowest_value["value"]) - loss_gen_history.append(current_loss_gen) - - # Update the smoothed loss_gen history - smoothed_value_gen = update_exponential_moving_average( - smoothed_loss_gen_history, current_loss_gen - ) - - # Check for overtraining with the smoothed loss_gen - is_overtraining_gen = check_overtraining( - smoothed_loss_gen_history, overtraining_threshold, 0.01 - ) - if is_overtraining_gen: - consecutive_increases_gen += 1 - else: - consecutive_increases_gen = 0 - - overtrain_info = f"Smoothed loss_g {smoothed_value_gen:.3f} and loss_d {smoothed_value_disc:.3f}" - # Save the data in the JSON file if the epoch is divisible by save_every_epoch - if epoch % save_every_epoch == 0: - save_to_json( - training_file_path, - loss_disc_history, - smoothed_loss_disc_history, - loss_gen_history, - smoothed_loss_gen_history, - ) - - if ( - is_overtraining_gen - and consecutive_increases_gen == overtraining_threshold - or is_overtraining_disc - and consecutive_increases_disc == (overtraining_threshold * 2) - ): - print( - f"Overtraining detected at epoch {epoch} with smoothed loss_g {smoothed_value_gen:.3f} and loss_d {smoothed_value_disc:.3f}" - ) - os._exit(2333333) - else: - print( - f"New best epoch {epoch} with smoothed loss_g {smoothed_value_gen:.3f} and loss_d {smoothed_value_disc:.3f}" - ) - old_model_files = glob.glob( - os.path.join(experiment_dir, f"{model_name}_*e_*s_best_epoch.pth") - ) - for file in old_model_files: - os.remove(file) - - if hasattr(net_g, "module"): - ckpt = net_g.module.state_dict() + consecutive_increases_gen = 0 + overtrain_info = f"Smoothed loss_g {smoothed_value_gen:.3f} and loss_d {smoothed_value_disc:.3f}" + # Save the data in the JSON file if the epoch is divisible by save_every_epoch + if epoch % save_every_epoch == 0: + save_to_json(training_file_path, loss_disc_history, smoothed_loss_disc_history, loss_gen_history, smoothed_loss_gen_history) + + if is_overtraining_gen and consecutive_increases_gen == overtraining_threshold or is_overtraining_disc and consecutive_increases_disc == overtraining_threshold * 2: + print(f"Overtraining detected at epoch {epoch} with smoothed loss_g {smoothed_value_gen:.3f} and loss_d {smoothed_value_disc:.3f}") + done = True else: - ckpt = net_g.state_dict() - if overtraining_detector != True: - overtrain_info = None - extract_model( - ckpt=ckpt, - sr=sample_rate, - pitch_guidance=pitch_guidance == True, - name=model_name, - model_dir=os.path.join( - experiment_dir, - f"{model_name}_{epoch}e_{global_step}s_best_epoch.pth", - ), - epoch=epoch, - step=global_step, - version=version, - hps=hps, - 
overtrain_info=overtrain_info, - ) + print(f"New best epoch {epoch} with smoothed loss_g {smoothed_value_gen:.3f} and loss_d {smoothed_value_disc:.3f}") + old_model_files = glob.glob(os.path.join(experiment_dir, f"{model_name}_*e_*s_best_epoch.pth")) + for file in old_model_files: + model_del.append(file) + model_add.append(os.path.join(experiment_dir, f"{model_name}_{epoch}e_{global_step}s_best_epoch.pth")) + + # Check completion + if epoch >= custom_total_epoch: + lowest_value_rounded = float(lowest_value["value"]) + lowest_value_rounded = round(lowest_value_rounded, 3) + print(f"Training has been successfully completed with {epoch} epoch, {global_step} steps and {round(loss_gen_all.item(), 3)} loss gen.") + print(f"Lowest generator loss: {lowest_value_rounded} at epoch {lowest_value['epoch']}, step {lowest_value['step']}") + + pid_file_path = os.path.join(experiment_dir, "config.json") + with open(pid_file_path, "r") as pid_file: + pid_data = json.load(pid_file) + with open(pid_file_path, "w") as pid_file: + pid_data.pop("process_pids", None) + json.dump(pid_data, pid_file, indent=4) + # Final model + model_add.append(os.path.join(experiment_dir, f"{model_name}_{epoch}e_{global_step}s.pth")) + done = True + + if model_add: + ckpt = net_g.module.state_dict() if hasattr(net_g, "module") else net_g.state_dict() + for m in model_add: + if not os.path.exists(m): + extract_model( + ckpt=ckpt, + sr=sample_rate, + pitch_guidance=pitch_guidance, + name=model_name, + model_dir=m, + epoch=epoch, + step=global_step, + version=version, + hps=hps, + overtrain_info=overtrain_info + ) + # Clean-up old best epochs + for m in model_del: + os.remove(m) - # Print training progress - if rank == 0: + # Print training progress lowest_value_rounded = float(lowest_value["value"]) lowest_value_rounded = round(lowest_value_rounded, 3) - if epoch > 1 and overtraining_detector == True: + record = f"{model_name} | epoch={epoch} | step={global_step} | {epoch_recorder.record()}" + if epoch > 1: + record = record + f" | lowest_value={lowest_value_rounded} (epoch {lowest_value['epoch']} and step {lowest_value['step']})" + + if overtraining_detector: remaining_epochs_gen = overtraining_threshold - consecutive_increases_gen - remaining_epochs_disc = ( - overtraining_threshold * 2 - ) - consecutive_increases_disc - print( - f"{model_name} | epoch={epoch} | step={global_step} | {epoch_recorder.record()} | lowest_value={lowest_value_rounded} (epoch {lowest_value['epoch']} and step {lowest_value['step']}) | Number of epochs remaining for overtraining: g/total: {remaining_epochs_gen} d/total: {remaining_epochs_disc} | smoothed_loss_gen={smoothed_value_gen:.3f} | smoothed_loss_disc={smoothed_value_disc:.3f}" - ) - elif epoch > 1 and overtraining_detector == False: - print( - f"{model_name} | epoch={epoch} | step={global_step} | {epoch_recorder.record()} | lowest_value={lowest_value_rounded} (epoch {lowest_value['epoch']} and step {lowest_value['step']})" - ) - else: - print( - f"{model_name} | epoch={epoch} | step={global_step} | {epoch_recorder.record()}" - ) + remaining_epochs_disc = overtraining_threshold * 2 - consecutive_increases_disc + record = record + f" | Number of epochs remaining for overtraining: g/total: {remaining_epochs_gen} d/total: {remaining_epochs_disc} | smoothed_loss_gen={smoothed_value_gen:.3f} | smoothed_loss_disc={smoothed_value_disc:.3f}" + print(record) last_loss_gen_all = loss_gen_all - # Save the final model - if epoch >= custom_total_epoch and rank == 0: - lowest_value_rounded = 
float(lowest_value["value"]) - lowest_value_rounded = round(lowest_value_rounded, 3) - print( - f"Training has been successfully completed with {epoch} epoch, {global_step} steps and {round(loss_gen_all.item(), 3)} loss gen." - ) - print( - f"Lowest generator loss: {lowest_value_rounded} at epoch {lowest_value['epoch']}, step {lowest_value['step']}" - ) + if done: + os._exit(2333333) - pid_file_path = os.path.join(experiment_dir, "config.json") - with open(pid_file_path, "r") as pid_file: - pid_data = json.load(pid_file) - with open(pid_file_path, "w") as pid_file: - pid_data.pop("process_pids", None) - json.dump(pid_data, pid_file, indent=4) +def check_overtraining(smoothed_loss_history, threshold, epsilon=0.004): + """ + Checks for overtraining based on the smoothed loss history. - if not os.path.exists( - os.path.join(experiment_dir, f"{model_name}_{epoch}e_{global_step}s.pth") - ): - if hasattr(net_g, "module"): - ckpt = net_g.module.state_dict() - else: - ckpt = net_g.state_dict() - if overtraining_detector != True: - overtrain_info = None - extract_model( - ckpt=ckpt, - sr=sample_rate, - pitch_guidance=pitch_guidance == True, - name=model_name, - model_dir=os.path.join( - experiment_dir, - f"{model_name}_{epoch}e_{global_step}s.pth", - ), - epoch=epoch, - step=global_step, - version=version, - hps=hps, - overtrain_info=overtrain_info, - ) - sleep(1) - os._exit(2333333) + Args: + smoothed_loss_history (list): List of smoothed losses for each epoch. + threshold (int): Number of consecutive epochs with insignificant changes or increases to consider overtraining. + epsilon (float): The maximum change considered insignificant. + """ + if len(smoothed_loss_history) < threshold + 1: + return False + + for i in range(-threshold, -1): + if smoothed_loss_history[i + 1] > smoothed_loss_history[i]: + return True + if abs(smoothed_loss_history[i + 1] - smoothed_loss_history[i]) >= epsilon: + return False + return True +def update_exponential_moving_average(smoothed_loss_history, new_value, smoothing=0.987): + """ + Updates the exponential moving average with a new value. + + Args: + smoothed_loss_history (list): List of smoothed values. + new_value (float): New value to be added. + smoothing (float): Smoothing factor. + """ + if smoothed_loss_history: + smoothed_value = smoothing * smoothed_loss_history[-1] + (1 - smoothing) * new_value + else: + smoothed_value = new_value + smoothed_loss_history.append(smoothed_value) + return smoothed_value + +def save_to_json( + file_path, + loss_disc_history, + smoothed_loss_disc_history, + loss_gen_history, + smoothed_loss_gen_history, + ): + """ + Save the training history to a JSON file. 
+ """ + data = { + "loss_disc_history": loss_disc_history, + "smoothed_loss_disc_history": smoothed_loss_disc_history, + "loss_gen_history": loss_gen_history, + "smoothed_loss_gen_history": smoothed_loss_gen_history, + } + with open(file_path, "w") as f: + json.dump(data, f) if __name__ == "__main__": torch.multiprocessing.set_start_method("spawn") diff --git a/tabs/train/train.py b/tabs/train/train.py index bbd266ff..3e735b09 100644 --- a/tabs/train/train.py +++ b/tabs/train/train.py @@ -621,12 +621,6 @@ def train_tab(): value=True, interactive=True, ) - use_cpu = gr.Checkbox( - label=i18n("Use CPU"), - info=i18n("Force the use of CPU for training."), - value=False, - interactive=True, - ) with gr.Column(): sync_graph = gr.Checkbox( label=i18n("Sync Graph"), @@ -778,7 +772,6 @@ def train_tab(): index_algorithm, cache_dataset_in_gpu, custom_pretrained, - use_cpu, g_pretrained_path, d_pretrained_path, ],