diff --git a/fine_tune.py b/fine_tune.py index 982dc8aec..11e94e560 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -2,7 +2,6 @@ # XXX dropped option: hypernetwork training import argparse -import gc import math import os from multiprocessing import Value @@ -11,6 +10,7 @@ from tqdm import tqdm import torch +from library.device_utils import clean_memory from library.ipex_interop import init_ipex init_ipex() @@ -158,9 +158,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): with torch.no_grad(): train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process) vae.to("cpu") - if torch.cuda.is_available(): - torch.cuda.empty_cache() - gc.collect() + clean_memory() accelerator.wait_for_everyone() diff --git a/finetune/make_captions.py b/finetune/make_captions.py index 074576bc2..524f80b9d 100644 --- a/finetune/make_captions.py +++ b/finetune/make_captions.py @@ -15,8 +15,9 @@ sys.path.append(os.path.dirname(__file__)) from blip.blip import blip_decoder, is_url import library.train_util as train_util +from library.device_utils import get_preferred_device -DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") +DEVICE = get_preferred_device() IMAGE_SIZE = 384 diff --git a/finetune/make_captions_by_git.py b/finetune/make_captions_by_git.py index b3c5cc423..2b650eb01 100644 --- a/finetune/make_captions_by_git.py +++ b/finetune/make_captions_by_git.py @@ -10,9 +10,9 @@ from transformers.generation.utils import GenerationMixin import library.train_util as train_util +from library.device_utils import get_preferred_device - -DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") +DEVICE = get_preferred_device() PATTERN_REPLACE = [ re.compile(r'(has|with|and) the (words?|letters?|name) (" ?[^"]*"|\w+)( ?(is )?(on|in) (the |her |their |him )?\w+)?'), diff --git a/finetune/prepare_buckets_latents.py b/finetune/prepare_buckets_latents.py index 1bccb1d3b..9d352dd6e 100644 --- 
a/finetune/prepare_buckets_latents.py +++ b/finetune/prepare_buckets_latents.py @@ -14,7 +14,9 @@ import library.model_util as model_util import library.train_util as train_util -DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") +from library.device_utils import get_preferred_device + +DEVICE = get_preferred_device() IMAGE_TRANSFORMS = transforms.Compose( [ diff --git a/gen_img_diffusers.py b/gen_img_diffusers.py index a207ad5a1..38b1ceabf 100644 --- a/gen_img_diffusers.py +++ b/gen_img_diffusers.py @@ -66,6 +66,7 @@ import numpy as np import torch +from library.device_utils import clean_memory, get_preferred_device from library.ipex_interop import init_ipex init_ipex() @@ -888,8 +889,7 @@ def __call__( init_latent_dist = self.vae.encode(init_image).latent_dist init_latents = init_latent_dist.sample(generator=generator) else: - if torch.cuda.is_available(): - torch.cuda.empty_cache() + clean_memory() init_latents = [] for i in tqdm(range(0, min(batch_size, len(init_image)), vae_batch_size)): init_latent_dist = self.vae.encode( @@ -1047,8 +1047,7 @@ def __call__( if vae_batch_size >= batch_size: image = self.vae.decode(latents).sample else: - if torch.cuda.is_available(): - torch.cuda.empty_cache() + clean_memory() images = [] for i in tqdm(range(0, batch_size, vae_batch_size)): images.append( @@ -2325,7 +2324,7 @@ def __getattr__(self, item): scheduler.config.clip_sample = True # deviceを決定する - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # "mps"を考量してない + device = get_preferred_device() # custom pipelineをコピったやつを生成する if args.vae_slices:
diff --git a/library/device_utils.py b/library/device_utils.py
new file mode 100644
index 000000000..353bfa9f3
--- /dev/null
+++ b/library/device_utils.py
@@ -0,0 +1,56 @@
+import functools
+import gc
+
+import torch
+
+# Probe each accelerator backend once, at import time. The probes are
+# best-effort: a torch build that lacks a backend (or raises while checking
+# it) is treated as not having that backend.
+# NOTE(review): callers in this change import this module before calling
+# init_ipex(); if init_ipex() alters what torch.cuda reports, these flags will
+# not reflect it -- confirm.
+try:
+    HAS_CUDA = torch.cuda.is_available()
+except Exception:
+    HAS_CUDA = False
+
+try:
+    HAS_MPS = torch.backends.mps.is_available()
+except Exception:
+    HAS_MPS = False
+
+
+def clean_memory():
+    """Run the garbage collector, then release cached accelerator memory.
+
+    Empties the CUDA and/or MPS allocator caches for whichever backends were
+    detected when this module was imported; a CPU-only setup only runs the
+    collector.
+    """
+    gc.collect()
+    if HAS_CUDA:
+        torch.cuda.empty_cache()
+    if HAS_MPS:
+        # torch.backends.mps can report an available device on PyTorch builds
+        # that predate torch.mps.empty_cache(); skip the call there instead of
+        # raising AttributeError.
+        mps_empty_cache = getattr(getattr(torch, "mps", None), "empty_cache", None)
+        if mps_empty_cache is not None:
+            mps_empty_cache()
+
+
+@functools.lru_cache(maxsize=None)
+def get_preferred_device() -> torch.device:
+    """Return the device to run on: CUDA if detected, else MPS, else CPU.
+
+    The result is cached, so the selection message is printed only on the
+    first call.
+    """
+    if HAS_CUDA:
+        device = torch.device("cuda")
+    elif HAS_MPS:
+        device = torch.device("mps")
+    else:
+        device = torch.device("cpu")
+    print(f"get_preferred_device() -> {device}")
+    return device
diff --git a/library/sdxl_train_util.py b/library/sdxl_train_util.py index 5ad748d15..d2becad6f 100644 --- a/library/sdxl_train_util.py +++ b/library/sdxl_train_util.py @@ -1,5 +1,4 @@ import argparse -import gc import math import os from typing import Optional @@ -8,6 +7,7 @@ from tqdm import tqdm from transformers import CLIPTokenizer from library import model_util, sdxl_model_util, train_util, sdxl_original_unet +from library.device_utils import clean_memory from library.sdxl_lpw_stable_diffusion import SdxlStableDiffusionLongPromptWeightingPipeline TOKENIZER1_PATH = "openai/clip-vit-large-patch14" @@ -47,8 +47,7 @@ def load_target_model(args, accelerator, model_version: str, weight_dtype): unet.to(accelerator.device) vae.to(accelerator.device) - gc.collect() - torch.cuda.empty_cache() + clean_memory() accelerator.wait_for_everyone() return load_stable_diffusion_format, text_encoder1, text_encoder2, vae, unet, logit_scale, ckpt_info diff --git a/library/train_util.py b/library/train_util.py index ba428e508..d59f42584 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -20,7 +20,6 @@ Union, ) from accelerate import Accelerator, InitProcessGroupKwargs, DistributedDataParallelKwargs -import gc import glob import math import os @@ -67,6 +66,7 @@ # from library.attention_processors import FlashAttnProcessor # from library.hypernetwork import replace_attentions_for_hypernetwork +from library.device_utils import clean_memory from library.original_unet import UNet2DConditionModel # Tokenizer: checkpointから読み込むのではなくあらかじめ提供されているものを使う @@ -2279,8 +2279,7 @@ def 
cache_batch_latents( info.latents_flipped = flipped_latent # FIXME this slows down caching a lot, specify this as an option - if torch.cuda.is_available(): - torch.cuda.empty_cache() + clean_memory() def cache_batch_text_encoder_outputs( @@ -3920,6 +3919,7 @@ def prepare_accelerator(args: argparse.Namespace): kwargs_handlers=kwargs_handlers, dynamo_backend=dynamo_backend, ) + print("accelerator device:", accelerator.device) return accelerator @@ -4006,8 +4006,7 @@ def load_target_model(args, weight_dtype, accelerator, unet_use_linear_projectio unet.to(accelerator.device) vae.to(accelerator.device) - gc.collect() - torch.cuda.empty_cache() + clean_memory() accelerator.wait_for_everyone() return text_encoder, vae, unet, load_stable_diffusion_format @@ -4816,7 +4815,7 @@ def sample_images_common( # clear pipeline and cache to reduce vram usage del pipeline - torch.cuda.empty_cache() + clean_memory() torch.set_rng_state(rng_state) if cuda_rng_state is not None: diff --git a/networks/lora_diffusers.py b/networks/lora_diffusers.py index 47d75ac4d..0056ac78e 100644 --- a/networks/lora_diffusers.py +++ b/networks/lora_diffusers.py @@ -11,6 +11,8 @@ from transformers import CLIPTextModel import torch +from library.device_utils import get_preferred_device + def make_unet_conversion_map() -> Dict[str, str]: unet_conversion_map_layer = [] @@ -476,7 +478,7 @@ def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True): from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline import torch - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + device = get_preferred_device() parser = argparse.ArgumentParser() parser.add_argument("--model_id", type=str, default=None, help="model id for huggingface") diff --git a/networks/lora_interrogator.py b/networks/lora_interrogator.py index 0dc066fd1..83942d7ce 100644 --- a/networks/lora_interrogator.py +++ b/networks/lora_interrogator.py @@ -9,11 +9,12 @@ import library.model_util as 
model_util import lora +from library.device_utils import get_preferred_device TOKENIZER_PATH = "openai/clip-vit-large-patch14" V2_STABLE_DIFFUSION_PATH = "stabilityai/stable-diffusion-2" # ここからtokenizerだけ使う -DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu') +DEVICE = get_preferred_device() def interrogate(args): diff --git a/sdxl_gen_img.py b/sdxl_gen_img.py index 0db9e340e..0722b93f4 100755 --- a/sdxl_gen_img.py +++ b/sdxl_gen_img.py @@ -18,6 +18,7 @@ import numpy as np import torch +from library.device_utils import clean_memory, get_preferred_device from library.ipex_interop import init_ipex init_ipex() @@ -640,8 +641,7 @@ def __call__( init_latent_dist = self.vae.encode(init_image.to(self.vae.dtype)).latent_dist init_latents = init_latent_dist.sample(generator=generator) else: - if torch.cuda.is_available(): - torch.cuda.empty_cache() + clean_memory() init_latents = [] for i in tqdm(range(0, min(batch_size, len(init_image)), vae_batch_size)): init_latent_dist = self.vae.encode( @@ -780,8 +780,7 @@ def __call__( if vae_batch_size >= batch_size: image = self.vae.decode(latents.to(self.vae.dtype)).sample else: - if torch.cuda.is_available(): - torch.cuda.empty_cache() + clean_memory() images = [] for i in tqdm(range(0, batch_size, vae_batch_size)): images.append( @@ -796,8 +795,7 @@ def __call__( # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 image = image.cpu().permute(0, 2, 3, 1).float().numpy() - if torch.cuda.is_available(): - torch.cuda.empty_cache() + clean_memory() if output_type == "pil": # image = self.numpy_to_pil(image) @@ -1497,7 +1495,7 @@ def __getattr__(self, item): # scheduler.config.clip_sample = True # deviceを決定する - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # "mps"を考量してない + device = get_preferred_device() # custom pipelineをコピったやつを生成する if args.vae_slices: diff --git a/sdxl_minimal_inference.py b/sdxl_minimal_inference.py index 
15a70678f..3eae20441 100644 --- a/sdxl_minimal_inference.py +++ b/sdxl_minimal_inference.py @@ -10,6 +10,7 @@ import numpy as np import torch +from library.device_utils import get_preferred_device from library.ipex_interop import init_ipex init_ipex() @@ -85,7 +86,7 @@ def get_timestep_embedding(x, outdim): guidance_scale = 7 seed = None # 1 - DEVICE = "cuda" + DEVICE = get_preferred_device() DTYPE = torch.float16 # bfloat16 may work parser = argparse.ArgumentParser() diff --git a/sdxl_train.py b/sdxl_train.py index a3f6f3a17..78cfaf495 100644 --- a/sdxl_train.py +++ b/sdxl_train.py @@ -1,7 +1,6 @@ # training with captions import argparse -import gc import math import os from multiprocessing import Value @@ -11,6 +10,7 @@ from tqdm import tqdm import torch +from library.device_utils import clean_memory from library.ipex_interop import init_ipex init_ipex() @@ -252,9 +252,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): with torch.no_grad(): train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process) vae.to("cpu") - if torch.cuda.is_available(): - torch.cuda.empty_cache() - gc.collect() + clean_memory() accelerator.wait_for_everyone() @@ -407,8 +405,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): # move Text Encoders for sampling images. 
Text Encoder doesn't work on CPU with fp16 text_encoder1.to("cpu", dtype=torch.float32) text_encoder2.to("cpu", dtype=torch.float32) - if torch.cuda.is_available(): - torch.cuda.empty_cache() + clean_memory() else: # make sure Text Encoders are on GPU text_encoder1.to(accelerator.device) diff --git a/sdxl_train_control_net_lllite.py b/sdxl_train_control_net_lllite.py index 7a88feb84..95b755f18 100644 --- a/sdxl_train_control_net_lllite.py +++ b/sdxl_train_control_net_lllite.py @@ -2,7 +2,6 @@ # training code for ControlNet-LLLite with passing cond_image to U-Net's forward import argparse -import gc import json import math import os @@ -15,6 +14,7 @@ from tqdm import tqdm import torch +from library.device_utils import clean_memory from library.ipex_interop import init_ipex init_ipex() @@ -164,9 +164,7 @@ def train(args): accelerator.is_main_process, ) vae.to("cpu") - if torch.cuda.is_available(): - torch.cuda.empty_cache() - gc.collect() + clean_memory() accelerator.wait_for_everyone() @@ -291,8 +289,7 @@ def train(args): # move Text Encoders for sampling images. 
Text Encoder doesn't work on CPU with fp16 text_encoder1.to("cpu", dtype=torch.float32) text_encoder2.to("cpu", dtype=torch.float32) - if torch.cuda.is_available(): - torch.cuda.empty_cache() + clean_memory() else: # make sure Text Encoders are on GPU text_encoder1.to(accelerator.device) diff --git a/sdxl_train_control_net_lllite_old.py b/sdxl_train_control_net_lllite_old.py index b94bf5c1b..fd24898c4 100644 --- a/sdxl_train_control_net_lllite_old.py +++ b/sdxl_train_control_net_lllite_old.py @@ -1,5 +1,4 @@ import argparse -import gc import json import math import os @@ -12,6 +11,7 @@ from tqdm import tqdm import torch +from library.device_utils import clean_memory from library.ipex_interop import init_ipex init_ipex() @@ -163,9 +163,7 @@ def train(args): accelerator.is_main_process, ) vae.to("cpu") - if torch.cuda.is_available(): - torch.cuda.empty_cache() - gc.collect() + clean_memory() accelerator.wait_for_everyone() @@ -264,8 +262,7 @@ def train(args): # move Text Encoders for sampling images. 
Text Encoder doesn't work on CPU with fp16 text_encoder1.to("cpu", dtype=torch.float32) text_encoder2.to("cpu", dtype=torch.float32) - if torch.cuda.is_available(): - torch.cuda.empty_cache() + clean_memory() else: # make sure Text Encoders are on GPU text_encoder1.to(accelerator.device) diff --git a/sdxl_train_network.py b/sdxl_train_network.py index 5d363280d..af0c8d1d7 100644 --- a/sdxl_train_network.py +++ b/sdxl_train_network.py @@ -1,6 +1,7 @@ import argparse import torch +from library.device_utils import clean_memory from library.ipex_interop import init_ipex init_ipex() @@ -65,8 +66,7 @@ def cache_text_encoder_outputs_if_needed( org_unet_device = unet.device vae.to("cpu") unet.to("cpu") - if torch.cuda.is_available(): - torch.cuda.empty_cache() + clean_memory() # When TE is not be trained, it will not be prepared so we need to use explicit autocast with accelerator.autocast(): @@ -81,8 +81,7 @@ def cache_text_encoder_outputs_if_needed( text_encoders[0].to("cpu", dtype=torch.float32) # Text Encoder doesn't work with fp16 on CPU text_encoders[1].to("cpu", dtype=torch.float32) - if torch.cuda.is_available(): - torch.cuda.empty_cache() + clean_memory() if not args.lowram: print("move vae and unet back to original device") diff --git a/tools/latent_upscaler.py b/tools/latent_upscaler.py index ab1fa3390..27d13ef61 100644 --- a/tools/latent_upscaler.py +++ b/tools/latent_upscaler.py @@ -15,6 +15,8 @@ from tqdm import tqdm from PIL import Image +from library.device_utils import get_preferred_device + class ResidualBlock(nn.Module): def __init__(self, in_channels, out_channels=None, kernel_size=3, stride=1, padding=1): @@ -255,7 +257,7 @@ def create_upscaler(**kwargs): # another interface: upscale images with a model for given images from command line def upscale_images(args: argparse.Namespace): - DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") + DEVICE = get_preferred_device() us_dtype = torch.float16 # TODO: support fp32/bf16 
os.makedirs(args.output_dir, exist_ok=True) diff --git a/train_controlnet.py b/train_controlnet.py index 7b0b2bbfe..e6bea2c9f 100644 --- a/train_controlnet.py +++ b/train_controlnet.py @@ -1,5 +1,4 @@ import argparse -import gc import json import math import os @@ -12,6 +11,7 @@ from tqdm import tqdm import torch +from library.device_utils import clean_memory from library.ipex_interop import init_ipex init_ipex() @@ -219,9 +219,7 @@ def train(args): accelerator.is_main_process, ) vae.to("cpu") - if torch.cuda.is_available(): - torch.cuda.empty_cache() - gc.collect() + clean_memory() accelerator.wait_for_everyone() diff --git a/train_db.py b/train_db.py index 888cad25e..daeb6d668 100644 --- a/train_db.py +++ b/train_db.py @@ -1,7 +1,6 @@ # DreamBooth training # XXX dropped option: fine_tune -import gc import argparse import itertools import math @@ -12,6 +11,7 @@ from tqdm import tqdm import torch +from library.device_utils import clean_memory from library.ipex_interop import init_ipex init_ipex() @@ -138,9 +138,7 @@ def train(args): with torch.no_grad(): train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process) vae.to("cpu") - if torch.cuda.is_available(): - torch.cuda.empty_cache() - gc.collect() + clean_memory() accelerator.wait_for_everyone() diff --git a/train_network.py b/train_network.py index 8d102ae8f..9aabd4d7c 100644 --- a/train_network.py +++ b/train_network.py @@ -1,6 +1,5 @@ import importlib import argparse -import gc import math import os import sys @@ -14,6 +13,7 @@ import torch from torch.nn.parallel import DistributedDataParallel as DDP +from library.device_utils import clean_memory from library.ipex_interop import init_ipex init_ipex() @@ -266,9 +266,7 @@ def train(self, args): with torch.no_grad(): train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process) vae.to("cpu") - if torch.cuda.is_available(): - torch.cuda.empty_cache() - 
gc.collect() + clean_memory() accelerator.wait_for_everyone() diff --git a/train_textual_inversion.py b/train_textual_inversion.py index 441c1e00b..821cfe786 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -1,5 +1,4 @@ import argparse -import gc import math import os from multiprocessing import Value @@ -8,6 +7,7 @@ from tqdm import tqdm import torch +from library.device_utils import clean_memory from library.ipex_interop import init_ipex init_ipex() @@ -363,9 +363,7 @@ def train(self, args): with torch.no_grad(): train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process) vae.to("cpu") - if torch.cuda.is_available(): - torch.cuda.empty_cache() - gc.collect() + clean_memory() accelerator.wait_for_everyone() diff --git a/train_textual_inversion_XTI.py b/train_textual_inversion_XTI.py index 7046a4808..ecd6d087f 100644 --- a/train_textual_inversion_XTI.py +++ b/train_textual_inversion_XTI.py @@ -1,6 +1,5 @@ import importlib import argparse -import gc import math import os import toml @@ -9,6 +8,7 @@ from tqdm import tqdm import torch +from library.device_utils import clean_memory from library.ipex_interop import init_ipex init_ipex() @@ -286,9 +286,7 @@ def train(args): with torch.no_grad(): train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process) vae.to("cpu") - if torch.cuda.is_available(): - torch.cuda.empty_cache() - gc.collect() + clean_memory() accelerator.wait_for_everyone()