Skip to content

Commit

Permalink
fix: lora config saving with multiple gpus
Browse files Browse the repository at this point in the history
  • Loading branch information
beniz committed May 18, 2024
1 parent f7f26a0 commit c98617d
Showing 1 changed file with 11 additions and 1 deletion.
12 changes: 11 additions & 1 deletion models/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -801,7 +801,17 @@ def save_networks(self, epoch):
net = getattr(self, "net" + name)

if len(self.gpu_ids) > 1 and self.use_cuda:
torch.save(net.module.state_dict(), save_path)
if (
name == "G_A"
and hasattr(net.module, "unet")
and hasattr(net.module, "vae")
and any(
"lora" in n for n, _ in net.module.unet.named_parameters()
)
):
net.module.save_lora_config(save_path)
else:
torch.save(net.module.state_dict(), save_path)
else:
if (
name == "G_A"
Expand Down

0 comments on commit c98617d

Please sign in to comment.