diff --git a/.gitignore b/.gitignore index e492b1add..19b6fd000 100644 --- a/.gitignore +++ b/.gitignore @@ -6,3 +6,4 @@ venv build .vscode wandb +.vs diff --git a/train_network.py b/train_network.py index 37fd2d4e8..9deb53313 100644 --- a/train_network.py +++ b/train_network.py @@ -813,6 +813,8 @@ def remove_model(old_ckpt_name): loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) if args.v_pred_like_loss: loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss) + if args.debiased_estimation_loss: + loss = apply_debiased_estimation(loss, timesteps, noise_scheduler) loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし