Skip to content

Commit

Permalink
Fix math for soft min SNR gamma
Browse files Browse the repository at this point in the history
  • Loading branch information
rockerBOO committed Jan 31, 2024
1 parent 38ef8ea commit 7468655
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions library/custom_train_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,12 @@ def apply_snr_weight(loss, timesteps, noise_scheduler, gamma, v_prediction=False

def apply_soft_snr_weight(loss, timesteps, noise_scheduler, gamma, v_prediction=False):
snr = torch.stack([noise_scheduler.all_snr[t] for t in timesteps])
soft_min_snr_gamma_weight = 1 / (torch.pow(snr if v_prediction is False else snr + 1, 2) + (1 / float(gamma)))
loss = loss * soft_min_snr_gamma_weight
snr_weight = (snr * gamma / (snr + gamma)).float().to(loss.device)
if v_prediction:
snr_weight /= snr + 1
else:
snr_weight /= snr
loss = loss * snr_weight
return loss


Expand Down

0 comments on commit 7468655

Please sign in to comment.