From 20296b4f0e5b0037dcbca277fcd4660e81b74906 Mon Sep 17 00:00:00 2001 From: Plat Date: Wed, 27 Dec 2023 02:12:37 +0900 Subject: [PATCH 1/2] chore: bump einops version to support torch.compile --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index c27131cd7..0a80d70d7 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,7 +4,7 @@ diffusers[torch]==0.21.2 ftfy==6.1.1 # albumentations==1.3.0 opencv-python==4.7.0.68 -einops==0.6.0 +einops==0.6.1 pytorch-lightning==1.9.0 # bitsandbytes==0.39.1 tensorboard==2.10.1 From 62e7516537e98fa9ab473fa940197dae879156b1 Mon Sep 17 00:00:00 2001 From: Plat Date: Wed, 27 Dec 2023 02:13:37 +0900 Subject: [PATCH 2/2] feat: support torch.compile --- library/train_util.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/library/train_util.py b/library/train_util.py index 3c850019e..e22afe1cb 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -2848,6 +2848,17 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: action="store_true", help="use memory efficient attention for CrossAttention / CrossAttentionに省メモリ版attentionを使う", ) + parser.add_argument("--torch_compile", action="store_true", help="use torch.compile (requires PyTorch 2.0) / torch.compile を使う") + parser.add_argument( + "--dynamo_backend", + type=str, + default="inductor", + # available backends: + # https://github.com/huggingface/accelerate/blob/d1abd59114ada8ba673e1214218cb2878c13b82d/src/accelerate/utils/dataclasses.py#L376-L388C5 + # https://pytorch.org/docs/stable/torch.compiler.html + choices=["eager", "aot_eager", "inductor", "aot_ts_nvfuser", "nvprims_nvfuser", "cudagraphs", "ofi", "fx2trt", "onnxrt"], + help="dynamo backend type (default is inductor) / dynamoのbackendの種類(デフォルトは inductor)" + ) parser.add_argument("--xformers", action="store_true", help="use xformers for CrossAttention / CrossAttentionにxformersを使う") 
parser.add_argument( "--sdpa", @@ -3869,6 +3880,11 @@ def prepare_accelerator(args: argparse.Namespace): os.environ["WANDB_DIR"] = logging_dir if args.wandb_api_key is not None: wandb.login(key=args.wandb_api_key) + + # torch.compile のオプション。 NO の場合は torch.compile は使わない + dynamo_backend = "NO" + if args.torch_compile: + dynamo_backend = args.dynamo_backend kwargs_handlers = ( InitProcessGroupKwargs(timeout=datetime.timedelta(minutes=args.ddp_timeout)) if args.ddp_timeout else None, @@ -3883,6 +3899,7 @@ def prepare_accelerator(args: argparse.Namespace): log_with=log_with, project_dir=logging_dir, kwargs_handlers=kwargs_handlers, + dynamo_backend=dynamo_backend, ) return accelerator