From cc4b4a56f2bf14290859c72b38f201850c4fe382 Mon Sep 17 00:00:00 2001
From: Aarni Koskela
Date: Tue, 16 Jan 2024 15:01:59 +0200
Subject: [PATCH] Refactor device determination to function; add MPS fallback

---
 finetune/make_captions.py           |  3 ++-
 finetune/make_captions_by_git.py    |  4 ++--
 finetune/prepare_buckets_latents.py |  4 +++-
 gen_img_diffusers.py                |  4 ++--
 library/device_utils.py             | 33 +++++++++++++++++++++++++++--
 networks/lora_diffusers.py          |  4 +++-
 networks/lora_interrogator.py       |  3 ++-
 sdxl_gen_img.py                     |  4 ++--
 sdxl_minimal_inference.py           |  5 ++++-
 tools/latent_upscaler.py            |  4 +++-
 10 files changed, 54 insertions(+), 14 deletions(-)

diff --git a/finetune/make_captions.py b/finetune/make_captions.py
index 074576bc2..524f80b9d 100644
--- a/finetune/make_captions.py
+++ b/finetune/make_captions.py
@@ -15,8 +15,9 @@
 sys.path.append(os.path.dirname(__file__))
 from blip.blip import blip_decoder, is_url
 import library.train_util as train_util
+from library.device_utils import get_preferred_device
 
-DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+DEVICE = get_preferred_device()
 
 
 IMAGE_SIZE = 384
diff --git a/finetune/make_captions_by_git.py b/finetune/make_captions_by_git.py
index b3c5cc423..2b650eb01 100644
--- a/finetune/make_captions_by_git.py
+++ b/finetune/make_captions_by_git.py
@@ -10,9 +10,9 @@
 from transformers.generation.utils import GenerationMixin
 
 import library.train_util as train_util
+from library.device_utils import get_preferred_device
 
-
-DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+DEVICE = get_preferred_device()
 
 PATTERN_REPLACE = [
     re.compile(r'(has|with|and) the (words?|letters?|name) (" ?[^"]*"|\w+)( ?(is )?(on|in) (the |her |their |him )?\w+)?'),
diff --git a/finetune/prepare_buckets_latents.py b/finetune/prepare_buckets_latents.py
index 1bccb1d3b..9d352dd6e 100644
--- a/finetune/prepare_buckets_latents.py
+++ b/finetune/prepare_buckets_latents.py
@@ -14,7 +14,9 @@
 import library.model_util as model_util
 import library.train_util as train_util
 
-DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+from library.device_utils import get_preferred_device
+
+DEVICE = get_preferred_device()
 
 IMAGE_TRANSFORMS = transforms.Compose(
     [
diff --git a/gen_img_diffusers.py b/gen_img_diffusers.py
index 6d72a22e8..0215bf988 100644
--- a/gen_img_diffusers.py
+++ b/gen_img_diffusers.py
@@ -66,7 +66,7 @@
 import numpy as np
 import torch
 
-from library.device_utils import clean_memory
+from library.device_utils import clean_memory, get_preferred_device
 
 try:
     import intel_extension_for_pytorch as ipex
@@ -2330,7 +2330,7 @@ def __getattr__(self, item):
         scheduler.config.clip_sample = True
 
     # deviceを決定する
-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # "mps"を考量してない
+    device = get_preferred_device()
 
     # custom pipelineをコピったやつを生成する
     if args.vae_slices:
diff --git a/library/device_utils.py b/library/device_utils.py
index 49af622bb..f2f4233d1 100644
--- a/library/device_utils.py
+++ b/library/device_utils.py
@@ -1,9 +1,38 @@
+import functools
 import gc
 
 import torch
 
+try:
+    HAS_CUDA = torch.cuda.is_available()
+except Exception:
+    HAS_CUDA = False
+
+try:
+    HAS_MPS = torch.backends.mps.is_available()
+except Exception:
+    HAS_MPS = False
+
 
 def clean_memory():
-    if torch.cuda.is_available():
-        torch.cuda.empty_cache()
     gc.collect()
+    if HAS_CUDA:
+        torch.cuda.empty_cache()
+    if HAS_MPS:
+        torch.mps.empty_cache()
+
+
+@functools.lru_cache(maxsize=None)
+def get_preferred_device() -> torch.device:
+    try:
+        import accelerate
+        device = accelerate.Accelerator().device
+    except Exception:
+        if HAS_CUDA:
+            device = torch.device("cuda")
+        elif HAS_MPS:
+            device = torch.device("mps")
+        else:
+            device = torch.device("cpu")
+    print(f"get_preferred_device() -> {device}")
+    return device
diff --git a/networks/lora_diffusers.py b/networks/lora_diffusers.py
index 47d75ac4d..0056ac78e 100644
--- a/networks/lora_diffusers.py
+++ b/networks/lora_diffusers.py
@@ -11,6 +11,8 @@
 from transformers import CLIPTextModel
 import torch
 
+from library.device_utils import get_preferred_device
+
 
 def make_unet_conversion_map() -> Dict[str, str]:
     unet_conversion_map_layer = []
@@ -476,7 +478,7 @@ def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
     from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline
     import torch
 
-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    device = get_preferred_device()
 
     parser = argparse.ArgumentParser()
     parser.add_argument("--model_id", type=str, default=None, help="model id for huggingface")
diff --git a/networks/lora_interrogator.py b/networks/lora_interrogator.py
index 0dc066fd1..83942d7ce 100644
--- a/networks/lora_interrogator.py
+++ b/networks/lora_interrogator.py
@@ -9,11 +9,12 @@
 import library.model_util as model_util
 import lora
 
+from library.device_utils import get_preferred_device
 
 TOKENIZER_PATH = "openai/clip-vit-large-patch14"
 V2_STABLE_DIFFUSION_PATH = "stabilityai/stable-diffusion-2"  # ここからtokenizerだけ使う
 
-DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+DEVICE = get_preferred_device()
 
 
 def interrogate(args):
diff --git a/sdxl_gen_img.py b/sdxl_gen_img.py
index 6b15a4a8c..972372d9a 100755
--- a/sdxl_gen_img.py
+++ b/sdxl_gen_img.py
@@ -18,7 +18,7 @@
 import numpy as np
 import torch
 
-from library.device_utils import clean_memory
+from library.device_utils import clean_memory, get_preferred_device
 
 try:
     import intel_extension_for_pytorch as ipex
@@ -1501,7 +1501,7 @@ def __getattr__(self, item):
     # scheduler.config.clip_sample = True
 
     # deviceを決定する
-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # "mps"を考量してない
+    device = get_preferred_device()
 
     # custom pipelineをコピったやつを生成する
     if args.vae_slices:
diff --git a/sdxl_minimal_inference.py b/sdxl_minimal_inference.py
index 45b9edd65..7cf4ec8c5 100644
--- a/sdxl_minimal_inference.py
+++ b/sdxl_minimal_inference.py
@@ -9,6 +9,9 @@
 from einops import repeat
 import numpy as np
 import torch
+
+from library.device_utils import get_preferred_device
+
 try:
     import intel_extension_for_pytorch as ipex
     if torch.xpu.is_available():
@@ -87,7 +90,7 @@ def get_timestep_embedding(x, outdim):
     guidance_scale = 7
     seed = None  # 1
 
-    DEVICE = "cuda"
+    DEVICE = get_preferred_device()
     DTYPE = torch.float16  # bfloat16 may work
 
     parser = argparse.ArgumentParser()
diff --git a/tools/latent_upscaler.py b/tools/latent_upscaler.py
index ab1fa3390..27d13ef61 100644
--- a/tools/latent_upscaler.py
+++ b/tools/latent_upscaler.py
@@ -15,6 +15,8 @@
 from tqdm import tqdm
 from PIL import Image
 
+from library.device_utils import get_preferred_device
+
 
 class ResidualBlock(nn.Module):
     def __init__(self, in_channels, out_channels=None, kernel_size=3, stride=1, padding=1):
@@ -255,7 +257,7 @@ def create_upscaler(**kwargs):
 
 # another interface: upscale images with a model for given images from command line
 def upscale_images(args: argparse.Namespace):
-    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    DEVICE = get_preferred_device()
     us_dtype = torch.float16  # TODO: support fp32/bf16
     os.makedirs(args.output_dir, exist_ok=True)
 
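
Usage note (an illustrative sketch, not part of the patch): with this change applied, callers ask the helper for a device once instead of building torch.device("cuda" if ...) by hand. A minimal example, assuming it is run from the repository root so the library package is importable:

    import torch

    from library.device_utils import clean_memory, get_preferred_device

    device = get_preferred_device()   # accelerate's device if installed; else cuda, then mps, then cpu
    x = torch.ones(4, device=device)  # tensors land on the selected backend
    print(x.device)
    clean_memory()                    # gc.collect(), then empty the CUDA/MPS caches if present

Because get_preferred_device() is wrapped in functools.lru_cache, the device is resolved (and the diagnostic line printed) only on the first call; later calls return the cached torch.device.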