From 1470b31b56429daa969e4f033f5bdb14da6022f5 Mon Sep 17 00:00:00 2001 From: jiaqiw Date: Tue, 10 Oct 2023 19:15:00 +0800 Subject: [PATCH 1/2] fix error reported 'find_unused_parameters' running in multiple GPUs or NPUs --- examples/dreambooth/train_dreambooth_lora_sdxl.py | 5 +++-- examples/text_to_image/train_text_to_image_lora_sdxl.py | 5 +++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py index ac59bba6c847..50dfa80f3602 100644 --- a/examples/dreambooth/train_dreambooth_lora_sdxl.py +++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py @@ -32,7 +32,7 @@ import transformers from accelerate import Accelerator from accelerate.logging import get_logger -from accelerate.utils import ProjectConfiguration, set_seed +from accelerate.utils import ProjectConfiguration, set_seed, DistributedDataParallelKwargs from huggingface_hub import create_repo, upload_folder from packaging import version from PIL import Image @@ -595,12 +595,13 @@ def main(args): logging_dir = Path(args.output_dir, args.logging_dir) accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) - + kwargs = DistributedDataParallelKwargs(find_unused_parameters=True) accelerator = Accelerator( gradient_accumulation_steps=args.gradient_accumulation_steps, mixed_precision=args.mixed_precision, log_with=args.report_to, project_config=accelerator_project_config, + kwargs_handlers=[kwargs], ) if args.report_to == "wandb": diff --git a/examples/text_to_image/train_text_to_image_lora_sdxl.py b/examples/text_to_image/train_text_to_image_lora_sdxl.py index ed7a15cd95fe..40cbdee5396a 100644 --- a/examples/text_to_image/train_text_to_image_lora_sdxl.py +++ b/examples/text_to_image/train_text_to_image_lora_sdxl.py @@ -33,7 +33,7 @@ import transformers from accelerate import Accelerator from accelerate.logging import get_logger -from accelerate.utils import
ProjectConfiguration, set_seed +from accelerate.utils import ProjectConfiguration, set_seed, DistributedDataParallelKwargs from datasets import load_dataset from huggingface_hub import create_repo, upload_folder from packaging import version @@ -491,12 +491,13 @@ def main(args): logging_dir = Path(args.output_dir, args.logging_dir) accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) - + kwargs = DistributedDataParallelKwargs(find_unused_parameters=True) accelerator = Accelerator( gradient_accumulation_steps=args.gradient_accumulation_steps, mixed_precision=args.mixed_precision, log_with=args.report_to, project_config=accelerator_project_config, + kwargs_handlers=[kwargs], ) if args.report_to == "wandb": From 5f2dfd36742a5a58fed52799cfdeeefbb46f1cdd Mon Sep 17 00:00:00 2001 From: Humphrey009 Date: Fri, 13 Oct 2023 21:40:41 +0800 Subject: [PATCH 2/2] fix code check of importing module by its alphabetic order --- examples/dreambooth/train_dreambooth_lora_sdxl.py | 2 +- examples/text_to_image/train_text_to_image_lora_sdxl.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py index 50dfa80f3602..ced16af6c939 100644 --- a/examples/dreambooth/train_dreambooth_lora_sdxl.py +++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py @@ -32,7 +32,7 @@ import transformers from accelerate import Accelerator from accelerate.logging import get_logger -from accelerate.utils import ProjectConfiguration, set_seed, DistributedDataParallelKwargs +from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed from huggingface_hub import create_repo, upload_folder from packaging import version from PIL import Image diff --git a/examples/text_to_image/train_text_to_image_lora_sdxl.py b/examples/text_to_image/train_text_to_image_lora_sdxl.py index 40cbdee5396a..0710de1b32c0 100644 --- 
a/examples/text_to_image/train_text_to_image_lora_sdxl.py +++ b/examples/text_to_image/train_text_to_image_lora_sdxl.py @@ -33,7 +33,7 @@ import transformers from accelerate import Accelerator from accelerate.logging import get_logger -from accelerate.utils import ProjectConfiguration, set_seed, DistributedDataParallelKwargs +from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed from datasets import load_dataset from huggingface_hub import create_repo, upload_folder from packaging import version