From 6539363c5c13a3e63fc0e52adf7fc26fb566d491 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9D=92=E9=BE=8D=E8=81=96=E8=80=85=40bdsqlsz?= Date: Fri, 20 Oct 2023 09:31:43 +0800 Subject: [PATCH] add train_network --- .gitignore | 1 + train_network.py | 2 ++ 2 files changed, 3 insertions(+) diff --git a/.gitignore b/.gitignore index e492b1add..19b6fd000 100644 --- a/.gitignore +++ b/.gitignore @@ -6,3 +6,4 @@ venv build .vscode wandb +.vs diff --git a/train_network.py b/train_network.py index 37fd2d4e8..9deb53313 100644 --- a/train_network.py +++ b/train_network.py @@ -813,6 +813,8 @@ def remove_model(old_ckpt_name): loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) if args.v_pred_like_loss: loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss) + if args.debiased_estimation_loss: + loss = apply_debiased_estimation(loss, timesteps, noise_scheduler) loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし