From 9973797f572d9de49a40c88805126ed3cfd3484b Mon Sep 17 00:00:00 2001
From: touchwolf
Date: Sun, 4 Aug 2024 22:51:07 +0800
Subject: [PATCH 1/4] Improve config support for transformers with accelerate

---
 train.py | 9 ++++++++-
 1 file changed, 8 insertions(+), 1 deletion(-)

diff --git a/train.py b/train.py
index 892a1240a..246009e9d 100644
--- a/train.py
+++ b/train.py
@@ -1942,7 +1942,14 @@ def main():
             width=latents.shape[3],
         )
         guidance_scale = 3 # >>> ????? <<<
-        if transformer.config.guidance_embeds:
+        original_config = transformer.config if hasattr(transformer, 'config') else None
+        if hasattr(transformer, 'module'):
+            transformer_config = transformer.module.config
+        elif hasattr(transformer, 'config'):
+            transformer_config = transformer.config
+        else:
+            transformer_config = original_config
+        if transformer_config and getattr(transformer_config, 'guidance_embeds', False):
             guidance = torch.tensor(
                 [guidance_scale], device=accelerator.device
             )

From a100b00df7d740d9aba6c7f3c5c57ac5400184ac Mon Sep 17 00:00:00 2001
From: Bagheera <59658056+bghira@users.noreply.github.com>
Date: Sun, 4 Aug 2024 09:08:38 -0600
Subject: [PATCH 2/4] Update train.py

---
 train.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/train.py b/train.py
index 246009e9d..b66f9ea9a 100644
--- a/train.py
+++ b/train.py
@@ -1942,7 +1942,7 @@ def main():
             width=latents.shape[3],
         )
         guidance_scale = 3 # >>> ????? <<<
-        original_config = transformer.config if hasattr(transformer, 'config') else None
+        transformer_config = None
         if hasattr(transformer, 'module'):
             transformer_config = transformer.module.config
         elif hasattr(transformer, 'config'):

From 4acc5762d0105411ed861a7ebd0f01e12a50a078 Mon Sep 17 00:00:00 2001
From: Bagheera <59658056+bghira@users.noreply.github.com>
Date: Sun, 4 Aug 2024 09:08:44 -0600
Subject: [PATCH 3/4] Update train.py

---
 train.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/train.py b/train.py
index b66f9ea9a..ed8891bc1 100644
--- a/train.py
+++ b/train.py
@@ -1949,7 +1949,7 @@ def main():
             transformer_config = transformer.config
         else:
             transformer_config = original_config
-        if transformer_config and getattr(transformer_config, 'guidance_embeds', False):
+        if transformer_config is not None and getattr(transformer_config, 'guidance_embeds', False):
             guidance = torch.tensor(
                 [guidance_scale], device=accelerator.device
             )

From 15c8a2fcdc7c9f2c1c97ad9b4b232b2c90cf9797 Mon Sep 17 00:00:00 2001
From: Bagheera <59658056+bghira@users.noreply.github.com>
Date: Sun, 4 Aug 2024 09:08:49 -0600
Subject: [PATCH 4/4] Update train.py

---
 train.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/train.py b/train.py
index ed8891bc1..8122189d6 100644
--- a/train.py
+++ b/train.py
@@ -1947,8 +1947,6 @@ def main():
             transformer_config = transformer.module.config
         elif hasattr(transformer, 'config'):
             transformer_config = transformer.config
-        else:
-            transformer_config = original_config
         if transformer_config is not None and getattr(transformer_config, 'guidance_embeds', False):
             guidance = torch.tensor(
                 [guidance_scale], device=accelerator.device
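
Taken together, the series replaces the direct `transformer.config.guidance_embeds` lookup with one that also works when accelerate wraps the model and exposes it through a `.module` attribute. The sketch below mirrors the end state of the hunk after PATCH 4/4, factored into standalone helpers purely for illustration; the helper names are hypothetical and do not exist in train.py, and `transformer`, `guidance_scale`, and `accelerator.device` are assumed to come from the surrounding training loop shown in the hunk context (the patched file keeps the logic inline instead).

# Illustrative sketch only: mirrors the hunk's final state after PATCH 4/4.
# The helper names are hypothetical and are not part of train.py.
import torch


def resolve_transformer_config(transformer):
    """Fetch the model config, unwrapping an accelerate/DDP `.module` wrapper if present."""
    if hasattr(transformer, "module"):
        return transformer.module.config
    if hasattr(transformer, "config"):
        return transformer.config
    return None


def maybe_build_guidance(transformer, guidance_scale, device):
    """Return a guidance tensor only when the resolved config reports guidance_embeds."""
    config = resolve_transformer_config(transformer)
    if config is not None and getattr(config, "guidance_embeds", False):
        return torch.tensor([guidance_scale], device=device)
    return None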