# inference_lcm.py (forked from PixArt-alpha/PixArt-alpha)
import os
import sys
from pathlib import Path
current_file_path = Path(__file__).resolve()
sys.path.insert(0, str(current_file_path.parent.parent))
import warnings
warnings.filterwarnings("ignore") # ignore warning
import re
import argparse
from datetime import datetime
from tqdm import tqdm
import torch
from torchvision.utils import save_image
from diffusers.models import AutoencoderKL
from diffusion.model.utils import prepare_prompt_ar
from tools.download import find_model
from diffusion.model.nets import PixArtMS_XL_2, PixArt_XL_2
from diffusion.model.t5 import T5Embedder
from diffusion.data.datasets import get_chunks
from diffusion.lcm_scheduler import LCMScheduler
from diffusion.data.datasets import ASPECT_RATIO_512_TEST, ASPECT_RATIO_1024_TEST
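
# NOTE: the diffusion.* and tools.* imports come from the PixArt-alpha repository;
# this script is meant to run from inside that repo (the sys.path insert above adds
# the directory one level above this file, presumably the repo root).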


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--image_size', default=1024, type=int)
    parser.add_argument('--t5_path', default='output/pretrained_models/t5_ckpts', type=str)
    parser.add_argument('--tokenizer_path', default='output/pretrained_models/sd-vae-ft-ema', type=str)
    parser.add_argument('--txt_file', default='asset/samples.txt', type=str)
    parser.add_argument('--model_path', default='output/pretrained_models/PixArt-XL-2-1024x1024.pth', type=str)
    parser.add_argument('--bs', default=1, type=int)
    parser.add_argument('--cfg_scale', default=4.5, type=float)
    parser.add_argument('--sample_steps', default=4, type=int)
    parser.add_argument('--seed', default=0, type=int)
    parser.add_argument('--dataset', default='custom', type=str)
    parser.add_argument('--step', default=-1, type=int)
    parser.add_argument('--save_name', default='test_sample', type=str)
    return parser.parse_args()
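

# Example invocation (paths are the defaults above; adjust the script location and
# point --model_path at your PixArt checkpoint, typically an LCM-distilled one for
# few-step sampling):
#   python inference_lcm.py --image_size 1024 --sample_steps 4 \
#       --txt_file asset/samples.txt \
#       --model_path output/pretrained_models/PixArt-XL-2-1024x1024.pth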


def set_env(seed=0):
    torch.manual_seed(seed)
    torch.set_grad_enabled(False)
    # draw from the RNG a fixed number of times to advance its state
    for _ in range(30):
        torch.randn(1, 4, args.image_size, args.image_size)


@torch.inference_mode()
def visualize(items, bs, sample_steps, cfg_scale):
    # Prepare timesteps (sample_steps LCM steps; 50 is the original-schedule step count)
    scheduler.set_timesteps(sample_steps, 50)
    timesteps = scheduler.timesteps

    for chunk in tqdm(list(get_chunks(items, bs)), unit='batch'):
        prompts = []
        if bs == 1:
            prompt_clean, _, hw, ar, custom_hw = prepare_prompt_ar(chunk[0], base_ratios, device=device, show=False)  # ar for aspect ratio
            if args.image_size == 1024:
                latent_size_h, latent_size_w = int(hw[0, 0] // 8), int(hw[0, 1] // 8)
            else:
                hw = torch.tensor([[args.image_size, args.image_size]], dtype=torch.float, device=device).repeat(bs, 1)
                ar = torch.tensor([[1.]], device=device).repeat(bs, 1)
                latent_size_h, latent_size_w = latent_size, latent_size
            prompts.append(prompt_clean.strip())
        else:
            hw = torch.tensor([[args.image_size, args.image_size]], dtype=torch.float, device=device).repeat(bs, 1)
            ar = torch.tensor([[1.]], device=device).repeat(bs, 1)
            for prompt in chunk:
                prompts.append(prepare_prompt_ar(prompt, base_ratios, device=device, show=False)[0].strip())
            latent_size_h, latent_size_w = latent_size, latent_size

        with torch.no_grad():
            caption_embs, emb_masks = t5.get_text_embeddings(prompts)
            caption_embs = caption_embs.float()[:, None]
            print('finish embedding')

            # Create sampling noise:
            n = len(prompts)
            latents = torch.randn(n, 4, latent_size_h, latent_size_w, device=device)
            model_kwargs = dict(data_info={'img_hw': hw, 'aspect_ratio': ar}, mask=emb_masks)

            # LCM multi-step sampling loop:
            for i, t in tqdm(list(enumerate(timesteps))):
                ts = torch.full((bs,), t, device=device, dtype=torch.long)

                # model prediction (v-prediction, eps, x)
                model_pred = model(latents, ts, caption_embs, **model_kwargs)[:, :4]

                # compute the previous noisy sample x_t -> x_t-1; `denoised` is the
                # predicted clean latent that gets decoded after the final step
                latents, denoised = scheduler.step(model_pred, i, t, latents, return_dict=False)

        samples = vae.decode(denoised / 0.18215).sample  # 0.18215 is the SD VAE latent scaling factor
        torch.cuda.empty_cache()

        # Save images:
        os.umask(0o000)  # file permission: 666; dir permission: 777
        for i, sample in enumerate(samples):
            save_path = os.path.join(save_root, f"{prompts[i][:100]}.jpg")
            print("Saving path: ", save_path)
            save_image(sample, save_path, nrow=1, normalize=True, value_range=(-1, 1))


if __name__ == '__main__':
    args = get_args()

    # Setup PyTorch:
    seed = args.seed
    set_env(seed)
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # only support fixed latent size currently
    latent_size = args.image_size // 8
    lewei_scale = {512: 1, 1024: 2}  # trick for positional embedding interpolation
    sample_steps = args.sample_steps

    # Initialize scheduler:
    scheduler = LCMScheduler(beta_start=0.0001, beta_end=0.02, beta_schedule="linear", prediction_type="epsilon")

    # model setting
    if args.image_size == 512:
        model = PixArt_XL_2(input_size=latent_size, lewei_scale=lewei_scale[args.image_size]).to(device)
    else:
        model = PixArtMS_XL_2(input_size=latent_size, lewei_scale=lewei_scale[args.image_size]).to(device)

    print(f"Generating sample from ckpt: {args.model_path}")
    state_dict = find_model(args.model_path)
    # drop the checkpoint's positional embedding; the model builds its own for the target resolution
    del state_dict['state_dict']['pos_embed']
    missing, unexpected = model.load_state_dict(state_dict['state_dict'], strict=False)
    print('Missing keys: ', missing)
    print('Unexpected keys: ', unexpected)
    model.eval()
    base_ratios = eval(f'ASPECT_RATIO_{args.image_size}_TEST')

    vae = AutoencoderKL.from_pretrained(args.tokenizer_path).to(device)
    t5 = T5Embedder(device=device, local_cache=True, cache_dir=args.t5_path, torch_dtype=torch.float)
    work_dir = os.path.join(*args.model_path.split('/')[:-2])
    work_dir = f'/{work_dir}' if args.model_path[0] == '/' else work_dir

    # data setting
    with open(args.txt_file, 'r') as f:
        items = [item.strip() for item in f.readlines()]

    # img save setting
    try:
        epoch_name = re.search(r'.*epoch_(\d+).*.pth', args.model_path).group(1)
        step_name = re.search(r'.*step_(\d+).*.pth', args.model_path).group(1)
    except Exception:
        epoch_name = 'unknown'
        step_name = 'unknown'

    img_save_dir = os.path.join(work_dir, 'vis')
    os.umask(0o000)  # file permission: 666; dir permission: 777
    os.makedirs(img_save_dir, exist_ok=True)

    save_root = os.path.join(img_save_dir, f"{datetime.now().date()}_{args.dataset}_epoch{epoch_name}_step{step_name}_scale{args.cfg_scale}_step{sample_steps}_size{args.image_size}_bs{args.bs}_sampLCM_seed{seed}")
    os.makedirs(save_root, exist_ok=True)
    visualize(items, args.bs, sample_steps, args.cfg_scale)