diff --git a/train.py b/train.py index 892a1240..8122189d 100644 --- a/train.py +++ b/train.py @@ -1942,7 +1942,12 @@ def main(): width=latents.shape[3], ) guidance_scale = 3 # >>> ????? <<< - if transformer.config.guidance_embeds: + transformer_config = None + if hasattr(transformer, 'module'): + transformer_config = transformer.module.config + elif hasattr(transformer, 'config'): + transformer_config = transformer.config + if transformer_config is not None and getattr(transformer_config, 'guidance_embeds', False): guidance = torch.tensor( [guidance_scale], device=accelerator.device )