diff --git a/library/custom_train_functions.py b/library/custom_train_functions.py index 6bd01a069..6cc391230 100644 --- a/library/custom_train_functions.py +++ b/library/custom_train_functions.py @@ -70,8 +70,12 @@ def apply_snr_weight(loss, timesteps, noise_scheduler, gamma, v_prediction=False def apply_soft_snr_weight(loss, timesteps, noise_scheduler, gamma, v_prediction=False): snr = torch.stack([noise_scheduler.all_snr[t] for t in timesteps]) - soft_min_snr_gamma_weight = 1 / (torch.pow(snr if v_prediction is False else snr + 1, 2) + (1 / float(gamma))) - loss = loss * soft_min_snr_gamma_weight + snr_weight = (snr * gamma / (snr + gamma)).float().to(loss.device) + if v_prediction: + snr_weight /= snr + 1 + else: + snr_weight /= snr + loss = loss * snr_weight return loss