diff --git a/driver.py b/driver.py index 489c27f..dee78f2 100644 --- a/driver.py +++ b/driver.py @@ -128,7 +128,7 @@ def main(): with accelerator.accumulate(model): loss = diffusion(mask, img) accelerator.log({'loss': loss}) # Log loss to wandb - loss.backward() + accelerator.backward(loss) optimizer.step() optimizer.zero_grad() running_loss += loss.item() * img.size(0)