Refactor memory cleaning into a single function
akx committed Jan 19, 2024
1 parent 39e9ade commit 1702ad8
Showing 17 changed files with 66 additions and 71 deletions.
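
In short, call sites that previously ran the CUDA-cache / garbage-collection cleanup inline now call one shared helper added in library/device_utils.py. A rough before/after sketch, paraphrased from the hunks below (the surrounding training code is omitted, and not every original site ran both steps):

    import gc

    import torch

    from library.device_utils import clean_memory

    # Before: cleanup repeated inline at each call site
    # (some sites cleared only the CUDA cache, others also ran the GC).
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    gc.collect()

    # After: the same call sites use the shared helper instead.
    clean_memory()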
7 changes: 3 additions & 4 deletions fine_tune.py
@@ -2,7 +2,6 @@
# XXX dropped option: hypernetwork training

import argparse
-import gc
import math
import os
from multiprocessing import Value
@@ -11,6 +10,8 @@
from tqdm import tqdm
import torch

+from library.device_utils import clean_memory
+
try:
import intel_extension_for_pytorch as ipex

@@ -163,9 +164,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
with torch.no_grad():
train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
vae.to("cpu")
-if torch.cuda.is_available():
-    torch.cuda.empty_cache()
-gc.collect()
+clean_memory()

accelerator.wait_for_everyone()

8 changes: 4 additions & 4 deletions gen_img_diffusers.py
@@ -66,6 +66,8 @@
import numpy as np
import torch

+from library.device_utils import clean_memory
+
try:
import intel_extension_for_pytorch as ipex

@@ -893,8 +895,7 @@ def __call__(
init_latent_dist = self.vae.encode(init_image).latent_dist
init_latents = init_latent_dist.sample(generator=generator)
else:
-if torch.cuda.is_available():
-    torch.cuda.empty_cache()
+clean_memory()
init_latents = []
for i in tqdm(range(0, min(batch_size, len(init_image)), vae_batch_size)):
init_latent_dist = self.vae.encode(
@@ -1052,8 +1053,7 @@ def __call__(
if vae_batch_size >= batch_size:
image = self.vae.decode(latents).sample
else:
-if torch.cuda.is_available():
-    torch.cuda.empty_cache()
+clean_memory()
images = []
for i in tqdm(range(0, batch_size, vae_batch_size)):
images.append(
9 changes: 9 additions & 0 deletions library/device_utils.py
@@ -0,0 +1,9 @@
+import gc
+
+import torch
+
+
+def clean_memory():
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+    gc.collect()
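
For reference, a minimal self-contained sketch of the new helper in use; the tensor below is illustrative and not taken from the repository, only the library.device_utils import comes from this commit:

    import torch

    from library.device_utils import clean_memory

    # Allocate something on the GPU when one is available (illustrative only).
    device = "cuda" if torch.cuda.is_available() else "cpu"
    x = torch.zeros(1024, 1024, device=device)

    del x           # drop the last reference so the allocator can actually free the memory
    clean_memory()  # empties the CUDA cache (when CUDA is present) and runs gc.collect()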
5 changes: 2 additions & 3 deletions library/sdxl_train_util.py
@@ -1,5 +1,4 @@
import argparse
-import gc
import math
import os
from typing import Optional
@@ -9,6 +8,7 @@
from transformers import CLIPTokenizer
from library import model_util, sdxl_model_util, train_util, sdxl_original_unet
from library.sdxl_lpw_stable_diffusion import SdxlStableDiffusionLongPromptWeightingPipeline
+from library.device_utils import clean_memory

TOKENIZER1_PATH = "openai/clip-vit-large-patch14"
TOKENIZER2_PATH = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
@@ -47,8 +47,7 @@ def load_target_model(args, accelerator, model_version: str, weight_dtype):
unet.to(accelerator.device)
vae.to(accelerator.device)

-gc.collect()
-torch.cuda.empty_cache()
+clean_memory()
accelerator.wait_for_everyone()

return load_stable_diffusion_format, text_encoder1, text_encoder2, vae, unet, logit_scale, ckpt_info
10 changes: 4 additions & 6 deletions library/train_util.py
@@ -20,7 +20,6 @@
Union,
)
from accelerate import Accelerator, InitProcessGroupKwargs, DistributedDataParallelKwargs
-import gc
import glob
import math
import os
@@ -68,6 +67,7 @@
# from library.attention_processors import FlashAttnProcessor
# from library.hypernetwork import replace_attentions_for_hypernetwork
from library.original_unet import UNet2DConditionModel
+from library.device_utils import clean_memory

# Tokenizer: checkpointから読み込むのではなくあらかじめ提供されているものを使う
TOKENIZER_PATH = "openai/clip-vit-large-patch14"
@@ -2267,8 +2267,7 @@ def cache_batch_latents(
info.latents_flipped = flipped_latent

# FIXME this slows down caching a lot, specify this as an option
-if torch.cuda.is_available():
-    torch.cuda.empty_cache()
+clean_memory()


def cache_batch_text_encoder_outputs(
@@ -3994,8 +3993,7 @@ def load_target_model(args, weight_dtype, accelerator, unet_use_linear_projectio
unet.to(accelerator.device)
vae.to(accelerator.device)

-gc.collect()
-torch.cuda.empty_cache()
+clean_memory()
accelerator.wait_for_everyone()

return text_encoder, vae, unet, load_stable_diffusion_format
@@ -4804,7 +4802,7 @@ def sample_images_common(

# clear pipeline and cache to reduce vram usage
del pipeline
-torch.cuda.empty_cache()
+clean_memory()

torch.set_rng_state(rng_state)
if cuda_rng_state is not None:
3 changes: 1 addition & 2 deletions library/utils.py
@@ -1,6 +1,5 @@
import threading
from typing import *


def fire_in_thread(f, *args, **kwargs):
-    threading.Thread(target=f, args=args, kwargs=kwargs).start()
+    threading.Thread(target=f, args=args, kwargs=kwargs).start()
8 changes: 4 additions & 4 deletions requirements.txt
@@ -3,15 +3,15 @@ transformers==4.36.2
diffusers[torch]==0.25.0
ftfy==6.1.1
# albumentations==1.3.0
-opencv-python==4.7.0.68
+opencv-python>=4.7.0.68
einops==0.6.1
pytorch-lightning==1.9.0
# bitsandbytes==0.39.1
tensorboard==2.10.1
-safetensors==0.3.1
+safetensors>=0.3.1
# gradio==3.16.2
-altair==4.2.2
-easygui==0.98.3
+#altair==4.2.2
+#easygui==0.98.3
toml==0.10.2
voluptuous==0.13.1
huggingface-hub==0.20.1
11 changes: 5 additions & 6 deletions sdxl_gen_img.py
@@ -18,6 +18,8 @@
import numpy as np
import torch

+from library.device_utils import clean_memory
+
try:
import intel_extension_for_pytorch as ipex

@@ -645,8 +647,7 @@ def __call__(
init_latent_dist = self.vae.encode(init_image.to(self.vae.dtype)).latent_dist
init_latents = init_latent_dist.sample(generator=generator)
else:
-if torch.cuda.is_available():
-    torch.cuda.empty_cache()
+clean_memory()
init_latents = []
for i in tqdm(range(0, min(batch_size, len(init_image)), vae_batch_size)):
init_latent_dist = self.vae.encode(
@@ -785,8 +786,7 @@ def __call__(
if vae_batch_size >= batch_size:
image = self.vae.decode(latents.to(self.vae.dtype)).sample
else:
-if torch.cuda.is_available():
-    torch.cuda.empty_cache()
+clean_memory()
images = []
for i in tqdm(range(0, batch_size, vae_batch_size)):
images.append(
@@ -801,8 +801,7 @@
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
image = image.cpu().permute(0, 2, 3, 1).float().numpy()

-if torch.cuda.is_available():
-    torch.cuda.empty_cache()
+clean_memory()

if output_type == "pil":
# image = self.numpy_to_pil(image)
10 changes: 4 additions & 6 deletions sdxl_train.py
@@ -1,7 +1,6 @@
# training with captions

import argparse
-import gc
import math
import os
from multiprocessing import Value
@@ -11,6 +10,8 @@
from tqdm import tqdm
import torch

+from library.device_utils import clean_memory
+
try:
import intel_extension_for_pytorch as ipex

@@ -257,9 +258,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
with torch.no_grad():
train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
vae.to("cpu")
-if torch.cuda.is_available():
-    torch.cuda.empty_cache()
-gc.collect()
+clean_memory()

accelerator.wait_for_everyone()

@@ -412,8 +411,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
# move Text Encoders for sampling images. Text Encoder doesn't work on CPU with fp16
text_encoder1.to("cpu", dtype=torch.float32)
text_encoder2.to("cpu", dtype=torch.float32)
-if torch.cuda.is_available():
-    torch.cuda.empty_cache()
+clean_memory()
else:
# make sure Text Encoders are on GPU
text_encoder1.to(accelerator.device)
11 changes: 5 additions & 6 deletions sdxl_train_control_net_lllite.py
@@ -2,7 +2,6 @@
# training code for ControlNet-LLLite with passing cond_image to U-Net's forward

import argparse
-import gc
import json
import math
import os
@@ -14,6 +13,9 @@

from tqdm import tqdm
import torch
+
+from library.device_utils import clean_memory
+
try:
import intel_extension_for_pytorch as ipex
if torch.xpu.is_available():
@@ -166,9 +168,7 @@ def train(args):
accelerator.is_main_process,
)
vae.to("cpu")
-if torch.cuda.is_available():
-    torch.cuda.empty_cache()
-gc.collect()
+clean_memory()

accelerator.wait_for_everyone()

@@ -293,8 +293,7 @@ def train(args):
# move Text Encoders for sampling images. Text Encoder doesn't work on CPU with fp16
text_encoder1.to("cpu", dtype=torch.float32)
text_encoder2.to("cpu", dtype=torch.float32)
-if torch.cuda.is_available():
-    torch.cuda.empty_cache()
+clean_memory()
else:
# make sure Text Encoders are on GPU
text_encoder1.to(accelerator.device)
11 changes: 5 additions & 6 deletions sdxl_train_control_net_lllite_old.py
@@ -1,5 +1,4 @@
import argparse
-import gc
import json
import math
import os
@@ -11,6 +10,9 @@

from tqdm import tqdm
import torch
+
+from library.device_utils import clean_memory
+
try:
import intel_extension_for_pytorch as ipex
if torch.xpu.is_available():
@@ -165,9 +167,7 @@ def train(args):
accelerator.is_main_process,
)
vae.to("cpu")
-if torch.cuda.is_available():
-    torch.cuda.empty_cache()
-gc.collect()
+clean_memory()

accelerator.wait_for_everyone()

@@ -266,8 +266,7 @@ def train(args):
# move Text Encoders for sampling images. Text Encoder doesn't work on CPU with fp16
text_encoder1.to("cpu", dtype=torch.float32)
text_encoder2.to("cpu", dtype=torch.float32)
-if torch.cuda.is_available():
-    torch.cuda.empty_cache()
+clean_memory()
else:
# make sure Text Encoders are on GPU
text_encoder1.to(accelerator.device)
8 changes: 4 additions & 4 deletions sdxl_train_network.py
@@ -1,6 +1,8 @@
import argparse
import torch

+from library.device_utils import clean_memory
+
try:
import intel_extension_for_pytorch as ipex

@@ -70,8 +72,7 @@ def cache_text_encoder_outputs_if_needed(
org_unet_device = unet.device
vae.to("cpu")
unet.to("cpu")
-if torch.cuda.is_available():
-    torch.cuda.empty_cache()
+clean_memory()

# When TE is not be trained, it will not be prepared so we need to use explicit autocast
with accelerator.autocast():
@@ -86,8 +87,7 @@

text_encoders[0].to("cpu", dtype=torch.float32) # Text Encoder doesn't work with fp16 on CPU
text_encoders[1].to("cpu", dtype=torch.float32)
-if torch.cuda.is_available():
-    torch.cuda.empty_cache()
+clean_memory()

if not args.lowram:
print("move vae and unet back to original device")
7 changes: 3 additions & 4 deletions train_controlnet.py
@@ -1,5 +1,4 @@
import argparse
-import gc
import json
import math
import os
@@ -12,6 +11,8 @@
from tqdm import tqdm
import torch

+from library.device_utils import clean_memory
+
try:
import intel_extension_for_pytorch as ipex

@@ -224,9 +225,7 @@ def train(args):
accelerator.is_main_process,
)
vae.to("cpu")
-if torch.cuda.is_available():
-    torch.cuda.empty_cache()
-gc.collect()
+clean_memory()

accelerator.wait_for_everyone()

(Diffs for the remaining 4 changed files are not shown here.)
