diff --git a/library/train_util.py b/library/train_util.py index 3c850019e..e22afe1cb 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -2848,6 +2848,17 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: action="store_true", help="use memory efficient attention for CrossAttention / CrossAttentionに省メモリ版attentionを使う", ) + parser.add_argument("--torch_compile", action="store_true", help="use torch.compile (requires PyTorch 2.0) / torch.compile を使う") + parser.add_argument( + "--dynamo_backend", + type=str, + default="inductor", + # available backends: + # https://github.com/huggingface/accelerate/blob/d1abd59114ada8ba673e1214218cb2878c13b82d/src/accelerate/utils/dataclasses.py#L376-L388C5 + # https://pytorch.org/docs/stable/torch.compiler.html + choices=["eager", "aot_eager", "inductor", "aot_ts_nvfuser", "nvprims_nvfuser", "cudagraphs", "ofi", "fx2trt", "onnxrt"], + help="dynamo backend type (default is inductor) / dynamoのbackendの種類(デフォルトは inductor)" + ) parser.add_argument("--xformers", action="store_true", help="use xformers for CrossAttention / CrossAttentionにxformersを使う") parser.add_argument( "--sdpa", @@ -3869,6 +3880,11 @@ def prepare_accelerator(args: argparse.Namespace): os.environ["WANDB_DIR"] = logging_dir if args.wandb_api_key is not None: wandb.login(key=args.wandb_api_key) + + # torch.compile のオプション。 NO の場合は torch.compile は使わない + dynamo_backend = "NO" + if args.torch_compile: + dynamo_backend = args.dynamo_backend kwargs_handlers = ( InitProcessGroupKwargs(timeout=datetime.timedelta(minutes=args.ddp_timeout)) if args.ddp_timeout else None, @@ -3883,6 +3899,7 @@ def prepare_accelerator(args: argparse.Namespace): log_with=log_with, project_dir=logging_dir, kwargs_handlers=kwargs_handlers, + dynamo_backend=dynamo_backend, ) return accelerator diff --git a/requirements.txt b/requirements.txt index c27131cd7..0a80d70d7 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,7 +4,7 @@ diffusers[torch]==0.21.2 ftfy==6.1.1 # albumentations==1.3.0 opencv-python==4.7.0.68 -einops==0.6.0 +einops==0.6.1 pytorch-lightning==1.9.0 # bitsandbytes==0.39.1 tensorboard==2.10.1