From c98617d1e534018d095c2c0ee96a3f8f9980fa8e Mon Sep 17 00:00:00 2001
From: beniz
Date: Sat, 18 May 2024 11:10:41 +0200
Subject: [PATCH] fix: lora config saving with multiple gpus

---
 models/base_model.py | 12 +++++++++++-
 1 file changed, 11 insertions(+), 1 deletion(-)

diff --git a/models/base_model.py b/models/base_model.py
index 54b438253..7ff4605f9 100644
--- a/models/base_model.py
+++ b/models/base_model.py
@@ -801,7 +801,17 @@ def save_networks(self, epoch):
             net = getattr(self, "net" + name)
 
             if len(self.gpu_ids) > 1 and self.use_cuda:
-                torch.save(net.module.state_dict(), save_path)
+                if (
+                    name == "G_A"
+                    and hasattr(net.module, "unet")
+                    and hasattr(net.module, "vae")
+                    and any(
+                        "lora" in n for n, _ in net.module.unet.named_parameters()
+                    )
+                ):
+                    net.module.save_lora_config(save_path)
+                else:
+                    torch.save(net.module.state_dict(), save_path)
             else:
                 if (
                     name == "G_A"
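
For readers outside the diff context: when training on several GPUs, `net` is wrapped in `torch.nn.DataParallel`, so the underlying generator has to be reached through `net.module` before either saving path runs. The snippet below is a minimal standalone sketch of that dispatch, assuming (as the patch does) a diffusion-style generator that exposes `unet`, `vae`, and a `save_lora_config(path)` method; the helper names `has_lora_params` and `save_generator` are illustrative only, not part of the patched code.

    # Minimal sketch of the save dispatch introduced by the patch. Assumes a
    # generator exposing `unet`, `vae`, and `save_lora_config(path)` as in the
    # patched base_model.py; helper names here are hypothetical.
    import torch


    def has_lora_params(unet: torch.nn.Module) -> bool:
        # LoRA adapters are identified by parameter name, mirroring the check in the patch.
        return any("lora" in n for n, _ in unet.named_parameters())


    def save_generator(net: torch.nn.Module, save_path: str, name: str = "G_A") -> None:
        # torch.nn.DataParallel wraps the real model in `.module`; unwrap it so the
        # checkpoint keys match single-GPU training.
        model = net.module if isinstance(net, torch.nn.DataParallel) else net

        if (
            name == "G_A"
            and hasattr(model, "unet")
            and hasattr(model, "vae")
            and has_lora_params(model.unet)
        ):
            # The generator persists its own LoRA config, as in the patched branch.
            model.save_lora_config(save_path)
        else:
            # Fall back to a full state_dict dump for non-LoRA networks.
            torch.save(model.state_dict(), save_path)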