diff --git a/fine_tune.py b/fine_tune.py
index be61b3d16..1b62783a2 100644
--- a/fine_tune.py
+++ b/fine_tune.py
@@ -2,7 +2,6 @@
 # XXX dropped option: hypernetwork training
 
 import argparse
-import gc
 import math
 import os
 from multiprocessing import Value
@@ -11,6 +10,8 @@
 from tqdm import tqdm
 
 import torch
+from library.device_utils import clean_memory
+
 try:
     import intel_extension_for_pytorch as ipex
@@ -163,9 +164,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
         with torch.no_grad():
             train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
         vae.to("cpu")
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
-        gc.collect()
+        clean_memory()
 
         accelerator.wait_for_everyone()
diff --git a/gen_img_diffusers.py b/gen_img_diffusers.py
index be43847a6..6d72a22e8 100644
--- a/gen_img_diffusers.py
+++ b/gen_img_diffusers.py
@@ -66,6 +66,8 @@
 import numpy as np
 import torch
+from library.device_utils import clean_memory
+
 try:
     import intel_extension_for_pytorch as ipex
@@ -893,8 +895,7 @@ def __call__(
                 init_latent_dist = self.vae.encode(init_image).latent_dist
                 init_latents = init_latent_dist.sample(generator=generator)
             else:
-                if torch.cuda.is_available():
-                    torch.cuda.empty_cache()
+                clean_memory()
                 init_latents = []
                 for i in tqdm(range(0, min(batch_size, len(init_image)), vae_batch_size)):
                     init_latent_dist = self.vae.encode(
@@ -1052,8 +1053,7 @@ def __call__(
         if vae_batch_size >= batch_size:
             image = self.vae.decode(latents).sample
         else:
-            if torch.cuda.is_available():
-                torch.cuda.empty_cache()
+            clean_memory()
             images = []
             for i in tqdm(range(0, batch_size, vae_batch_size)):
                 images.append(
diff --git a/library/device_utils.py b/library/device_utils.py
new file mode 100644
index 000000000..49af622bb
--- /dev/null
+++ b/library/device_utils.py
@@ -0,0 +1,9 @@
+import gc
+
+import torch
+
+
+def clean_memory():
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+    gc.collect()
diff --git a/library/sdxl_train_util.py b/library/sdxl_train_util.py
index 5ad748d15..e49a42e50 100644
--- a/library/sdxl_train_util.py
+++ b/library/sdxl_train_util.py
@@ -1,5 +1,4 @@
 import argparse
-import gc
 import math
 import os
 from typing import Optional
@@ -9,6 +8,7 @@
 from transformers import CLIPTokenizer
 from library import model_util, sdxl_model_util, train_util, sdxl_original_unet
 from library.sdxl_lpw_stable_diffusion import SdxlStableDiffusionLongPromptWeightingPipeline
+from library.device_utils import clean_memory
 
 TOKENIZER1_PATH = "openai/clip-vit-large-patch14"
 TOKENIZER2_PATH = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
@@ -47,8 +47,7 @@ def load_target_model(args, accelerator, model_version: str, weight_dtype):
             unet.to(accelerator.device)
             vae.to(accelerator.device)
 
-            gc.collect()
-            torch.cuda.empty_cache()
+            clean_memory()
             accelerator.wait_for_everyone()
 
     return load_stable_diffusion_format, text_encoder1, text_encoder2, vae, unet, logit_scale, ckpt_info
diff --git a/library/train_util.py b/library/train_util.py
index fae429edf..2b924846e 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -20,7 +20,6 @@
     Union,
 )
 from accelerate import Accelerator, InitProcessGroupKwargs, DistributedDataParallelKwargs
-import gc
 import glob
 import math
 import os
@@ -68,6 +67,7 @@
 # from library.attention_processors import FlashAttnProcessor
 # from library.hypernetwork import replace_attentions_for_hypernetwork
 from library.original_unet import UNet2DConditionModel
+from library.device_utils import clean_memory
 
 # Tokenizer: use the pre-provided tokenizer instead of loading it from the checkpoint
 TOKENIZER_PATH = "openai/clip-vit-large-patch14"
@@ -2267,8 +2267,7 @@ def cache_batch_latents(
                 info.latents_flipped = flipped_latent
 
     # FIXME this slows down caching a lot, specify this as an option
-    if torch.cuda.is_available():
-        torch.cuda.empty_cache()
+    clean_memory()
 
 
 def cache_batch_text_encoder_outputs(
@@ -3994,8 +3993,7 @@ def load_target_model(args, weight_dtype, accelerator, unet_use_linear_projectio
             unet.to(accelerator.device)
             vae.to(accelerator.device)
 
-            gc.collect()
-            torch.cuda.empty_cache()
+            clean_memory()
             accelerator.wait_for_everyone()
 
     return text_encoder, vae, unet, load_stable_diffusion_format
@@ -4804,7 +4802,7 @@ def sample_images_common(
     # clear pipeline and cache to reduce vram usage
     del pipeline
-    torch.cuda.empty_cache()
+    clean_memory()
 
     torch.set_rng_state(rng_state)
     if cuda_rng_state is not None:
diff --git a/library/utils.py b/library/utils.py
index 7d801a676..c96fcb144 100644
--- a/library/utils.py
+++ b/library/utils.py
@@ -1,6 +1,5 @@
 import threading
-from typing import *
 
 
 def fire_in_thread(f, *args, **kwargs):
-    threading.Thread(target=f, args=args, kwargs=kwargs).start()
\ No newline at end of file
+    threading.Thread(target=f, args=args, kwargs=kwargs).start()
diff --git a/requirements.txt b/requirements.txt
index 8517d95ac..65c22258f 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -3,15 +3,15 @@
 transformers==4.36.2
 diffusers[torch]==0.25.0
 ftfy==6.1.1
 # albumentations==1.3.0
-opencv-python==4.7.0.68
+opencv-python>=4.7.0.68
 einops==0.6.1
 pytorch-lightning==1.9.0
 # bitsandbytes==0.39.1
 tensorboard==2.10.1
-safetensors==0.3.1
+safetensors>=0.3.1
 # gradio==3.16.2
-altair==4.2.2
-easygui==0.98.3
+#altair==4.2.2
+#easygui==0.98.3
 toml==0.10.2
 voluptuous==0.13.1
 huggingface-hub==0.20.1
diff --git a/sdxl_gen_img.py b/sdxl_gen_img.py
index ab5399842..6b15a4a8c 100755
--- a/sdxl_gen_img.py
+++ b/sdxl_gen_img.py
@@ -18,6 +18,8 @@
 import numpy as np
 import torch
+from library.device_utils import clean_memory
+
 try:
     import intel_extension_for_pytorch as ipex
@@ -645,8 +647,7 @@ def __call__(
                 init_latent_dist = self.vae.encode(init_image.to(self.vae.dtype)).latent_dist
                 init_latents = init_latent_dist.sample(generator=generator)
             else:
-                if torch.cuda.is_available():
-                    torch.cuda.empty_cache()
+                clean_memory()
                 init_latents = []
                 for i in tqdm(range(0, min(batch_size, len(init_image)), vae_batch_size)):
                     init_latent_dist = self.vae.encode(
@@ -785,8 +786,7 @@ def __call__(
         if vae_batch_size >= batch_size:
             image = self.vae.decode(latents.to(self.vae.dtype)).sample
         else:
-            if torch.cuda.is_available():
-                torch.cuda.empty_cache()
+            clean_memory()
             images = []
             for i in tqdm(range(0, batch_size, vae_batch_size)):
                 images.append(
@@ -801,8 +801,7 @@
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
         image = image.cpu().permute(0, 2, 3, 1).float().numpy()
 
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
+        clean_memory()
 
         if output_type == "pil":
             # image = self.numpy_to_pil(image)
diff --git a/sdxl_train.py b/sdxl_train.py
index b4ce2770e..817c15831 100644
--- a/sdxl_train.py
+++ b/sdxl_train.py
@@ -1,7 +1,6 @@
 # training with captions
 
 import argparse
-import gc
 import math
 import os
 from multiprocessing import Value
@@ -11,6 +10,8 @@
 from tqdm import tqdm
 
 import torch
+from library.device_utils import clean_memory
+
 try:
     import intel_extension_for_pytorch as ipex
@@ -257,9 +258,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
         with torch.no_grad():
             train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
         vae.to("cpu")
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
-        gc.collect()
+        clean_memory()
 
         accelerator.wait_for_everyone()
@@ -412,8 +411,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
         # move Text Encoders for sampling images. Text Encoder doesn't work on CPU with fp16
         text_encoder1.to("cpu", dtype=torch.float32)
         text_encoder2.to("cpu", dtype=torch.float32)
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
+        clean_memory()
     else:
         # make sure Text Encoders are on GPU
         text_encoder1.to(accelerator.device)
diff --git a/sdxl_train_control_net_lllite.py b/sdxl_train_control_net_lllite.py
index 4436dd3cd..e86443a40 100644
--- a/sdxl_train_control_net_lllite.py
+++ b/sdxl_train_control_net_lllite.py
@@ -2,7 +2,6 @@
 # training code for ControlNet-LLLite with passing cond_image to U-Net's forward
 
 import argparse
-import gc
 import json
 import math
 import os
@@ -14,6 +13,9 @@
 from tqdm import tqdm
 
 import torch
+
+from library.device_utils import clean_memory
+
 try:
     import intel_extension_for_pytorch as ipex
     if torch.xpu.is_available():
@@ -166,9 +168,7 @@ def train(args):
             accelerator.is_main_process,
         )
         vae.to("cpu")
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
-        gc.collect()
+        clean_memory()
 
         accelerator.wait_for_everyone()
@@ -293,8 +293,7 @@ def train(args):
         # move Text Encoders for sampling images. Text Encoder doesn't work on CPU with fp16
         text_encoder1.to("cpu", dtype=torch.float32)
         text_encoder2.to("cpu", dtype=torch.float32)
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
+        clean_memory()
     else:
         # make sure Text Encoders are on GPU
         text_encoder1.to(accelerator.device)
diff --git a/sdxl_train_control_net_lllite_old.py b/sdxl_train_control_net_lllite_old.py
index 6ae5377ba..4629a7378 100644
--- a/sdxl_train_control_net_lllite_old.py
+++ b/sdxl_train_control_net_lllite_old.py
@@ -1,5 +1,4 @@
 import argparse
-import gc
 import json
 import math
 import os
@@ -11,6 +10,9 @@
 from tqdm import tqdm
 
 import torch
+
+from library.device_utils import clean_memory
+
 try:
     import intel_extension_for_pytorch as ipex
     if torch.xpu.is_available():
@@ -165,9 +167,7 @@ def train(args):
             accelerator.is_main_process,
         )
         vae.to("cpu")
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
-        gc.collect()
+        clean_memory()
 
         accelerator.wait_for_everyone()
@@ -266,8 +266,7 @@ def train(args):
         # move Text Encoders for sampling images. Text Encoder doesn't work on CPU with fp16
         text_encoder1.to("cpu", dtype=torch.float32)
         text_encoder2.to("cpu", dtype=torch.float32)
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
+        clean_memory()
     else:
         # make sure Text Encoders are on GPU
         text_encoder1.to(accelerator.device)
diff --git a/sdxl_train_network.py b/sdxl_train_network.py
index d810ce7d4..822c1e153 100644
--- a/sdxl_train_network.py
+++ b/sdxl_train_network.py
@@ -1,6 +1,8 @@
 import argparse
 
 import torch
+from library.device_utils import clean_memory
+
 try:
     import intel_extension_for_pytorch as ipex
@@ -70,8 +72,7 @@ def cache_text_encoder_outputs_if_needed(
             org_unet_device = unet.device
             vae.to("cpu")
             unet.to("cpu")
-            if torch.cuda.is_available():
-                torch.cuda.empty_cache()
+            clean_memory()
 
             # When TE is not being trained, it will not be prepared, so we need to use explicit autocast
             with accelerator.autocast():
@@ -86,8 +87,7 @@ def cache_text_encoder_outputs_if_needed(
             text_encoders[0].to("cpu", dtype=torch.float32)  # Text Encoder doesn't work with fp16 on CPU
             text_encoders[1].to("cpu", dtype=torch.float32)
-            if torch.cuda.is_available():
-                torch.cuda.empty_cache()
+            clean_memory()
 
             if not args.lowram:
                 print("move vae and unet back to original device")
diff --git a/train_controlnet.py b/train_controlnet.py
index cc0eaab7a..89310116b 100644
--- a/train_controlnet.py
+++ b/train_controlnet.py
@@ -1,5 +1,4 @@
 import argparse
-import gc
 import json
 import math
 import os
@@ -12,6 +11,8 @@
 from tqdm import tqdm
 
 import torch
+from library.device_utils import clean_memory
+
 try:
     import intel_extension_for_pytorch as ipex
@@ -224,9 +225,7 @@ def train(args):
             accelerator.is_main_process,
         )
         vae.to("cpu")
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
-        gc.collect()
+        clean_memory()
 
         accelerator.wait_for_everyone()
diff --git a/train_db.py b/train_db.py
index 14d9dff13..cda321d54 100644
--- a/train_db.py
+++ b/train_db.py
@@ -1,7 +1,6 @@
 # DreamBooth training
 # XXX dropped option: fine_tune
 
-import gc
 import argparse
 import itertools
 import math
@@ -12,6 +11,8 @@
 from tqdm import tqdm
 
 import torch
+from library.device_utils import clean_memory
+
 try:
     import intel_extension_for_pytorch as ipex
@@ -143,9 +144,7 @@ def train(args):
         with torch.no_grad():
             train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
         vae.to("cpu")
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
-        gc.collect()
+        clean_memory()
 
         accelerator.wait_for_everyone()
diff --git a/train_network.py b/train_network.py
index c2b7fbdef..2aa23d0be 100644
--- a/train_network.py
+++ b/train_network.py
@@ -1,6 +1,5 @@
 import importlib
 import argparse
-import gc
 import math
 import os
 import sys
@@ -14,6 +13,8 @@
 import torch
 from torch.nn.parallel import DistributedDataParallel as DDP
+from library.device_utils import clean_memory
+
 try:
     import intel_extension_for_pytorch as ipex
@@ -271,9 +272,7 @@ def train(self, args):
         with torch.no_grad():
             train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
         vae.to("cpu")
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
-        gc.collect()
+        clean_memory()
 
         accelerator.wait_for_everyone()
diff --git a/train_textual_inversion.py b/train_textual_inversion.py
index 0e3912b1d..954715269 100644
--- a/train_textual_inversion.py
+++ b/train_textual_inversion.py
@@ -1,5 +1,4 @@
 import argparse
-import gc
 import math
 import os
 from multiprocessing import Value
@@ -8,6 +7,8 @@
 from tqdm import tqdm
 
 import torch
+from library.device_utils import clean_memory
+
 try:
     import intel_extension_for_pytorch as ipex
@@ -368,9 +369,7 @@ def train(self, args):
         with torch.no_grad():
             train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
         vae.to("cpu")
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
-        gc.collect()
+        clean_memory()
 
         accelerator.wait_for_everyone()
diff --git a/train_textual_inversion_XTI.py b/train_textual_inversion_XTI.py
index 71b43549d..0417ba02a 100644
--- a/train_textual_inversion_XTI.py
+++ b/train_textual_inversion_XTI.py
@@ -1,6 +1,5 @@
 import importlib
 import argparse
-import gc
 import math
 import os
 import toml
@@ -8,6 +7,9 @@
 from tqdm import tqdm
 
 import torch
+
+from library.device_utils import clean_memory
+
 try:
     import intel_extension_for_pytorch as ipex
     if torch.xpu.is_available():
@@ -288,9 +290,7 @@ def train(args):
         with torch.no_grad():
             train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
         vae.to("cpu")
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
-        gc.collect()
+        clean_memory()
 
         accelerator.wait_for_everyone()
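In short, each script replaces the inline empty-cache/garbage-collect idiom with a single call to the new helper. A self-contained sketch of the resulting pattern follows: clean_memory is copied from the new library/device_utils.py above, while the offload_to_cpu wrapper is hypothetical and added only for illustration.

    import gc

    import torch


    def clean_memory():
        # As introduced in library/device_utils.py: release cached CUDA blocks
        # (when a GPU is present) and collect Python garbage.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()


    def offload_to_cpu(module: torch.nn.Module) -> None:
        # Hypothetical caller showing how the training and generation scripts
        # use the helper after moving a model off the GPU; previously each of
        # these call sites inlined torch.cuda.empty_cache() and gc.collect().
        module.to("cpu")
        clean_memory()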