-
Notifications
You must be signed in to change notification settings - Fork 4
/
sd.py
133 lines (113 loc) · 4.87 KB
/
sd.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
import time
import torch # , tomesd
import torch.nn as nn
from transformers import logging
from diffusers.utils.import_utils import is_xformers_available
from diffusers import (
DiffusionPipeline
)
# Only let the transformers library report errors (hides its info/warning output).
logging.set_verbosity_error()
def seed_everything(seed):
    """Seed PyTorch's random number generators with ``seed``.

    Calls ``torch.manual_seed`` followed by ``torch.cuda.manual_seed`` with
    the same value. Only PyTorch is touched: Python's ``random`` module and
    NumPy are not seeded here.

    Args:
        seed: Integer seed handed to both seeding calls.
    """
    for seed_fn in (torch.manual_seed, torch.cuda.manual_seed):
        seed_fn(seed)
class StableDiffusion(nn.Module):
    """Bundle the parts of a pretrained Stable Diffusion pipeline used for guidance.

    On construction the pipeline is loaded in float16, its VAE, tokenizer,
    UNet and scheduler are kept on the instance, and the prompt / negative
    prompt embeddings are computed once and cached as ``text_z`` and
    ``uncond_z``. The text encoder is deleted afterwards, so
    ``get_text_embeds`` and ``get_uncond_embeds`` can only be used while
    ``__init__`` is running.

    Args:
        device: Device the pipeline and cached tensors are moved to.
        mode: Free-form mode string. In this class only the value
            ``'appearance_modeling'`` has an effect: ``encode_imgs`` then
            rescales its input with ``2 * imgs - 1`` before encoding.
        text: Positive prompt.
        add_directional_text: If true, four prompts are encoded, one per
            view in ``front, side, back, side`` order (``"<text>, <view> view"``),
            each with batch size 1, and concatenated along dim 0. ``batch``
            is not applied in that case.
        batch: Number of times the single prompt embedding is repeated when
            ``add_directional_text`` is false.
        guidance_weight: Stored on the instance; not used by the methods of
            this class.
        sds_weight_strategy: Stored on the instance; not used by the methods
            of this class.
        early_time_step_range: ``(low, high)`` fractions of the scheduler's
            ``num_train_timesteps``; stored as integer steps in
            ``min_step_early`` / ``max_step_early``.
        late_time_step_range: Same as above for ``min_step_late`` /
            ``max_step_late``.
        sd_version: One of ``'2.1'``, ``'2.0'`` or ``'1.5'``.
        negative_text: Negative prompt. For the ``back`` view the word
            ``face`` is appended to it.

    Raises:
        ValueError: If ``sd_version`` is not a supported version.
    """

    # Hugging Face model ids for every supported ``sd_version``.
    _MODEL_KEYS = {
        '2.1': "stabilityai/stable-diffusion-2-1-base",
        '2.0': "stabilityai/stable-diffusion-2-base",
        '1.5': "runwayml/stable-diffusion-v1-5",
    }

    # View names used for directional prompts; the order fixes the row order
    # of ``text_z`` / ``uncond_z`` ('side' appears twice on purpose).
    _VIEW_DIRECTIONS = ('front', 'side', 'back', 'side')

    def __init__(self,
                 device,
                 mode='geometry',
                 text='',
                 add_directional_text=False,
                 batch=1,
                 guidance_weight=100,
                 sds_weight_strategy=0,
                 early_time_step_range=(0.02, 0.5),
                 late_time_step_range=(0.02, 0.5),
                 sd_version='2.1',
                 negative_text=''):
        super().__init__()

        # Fail early with a clear message instead of hitting an unbound
        # local variable once the model is about to be loaded.
        if sd_version not in self._MODEL_KEYS:
            raise ValueError(
                f"Unsupported sd_version {sd_version!r}; "
                f"expected one of {sorted(self._MODEL_KEYS)}")

        self.device = device
        self.mode = mode
        self.text = text
        self.add_directional_text = add_directional_text
        self.batch = batch
        self.sd_version = sd_version
        self.negative_text = negative_text
        self.guidance_weight = guidance_weight
        self.sds_weight_strategy = sds_weight_strategy

        print('[INFO] loading stable diffusion...')
        pipeline = DiffusionPipeline.from_pretrained(
            self._MODEL_KEYS[sd_version],
            torch_dtype=torch.float16).to(self.device)

        self.vae = pipeline.vae
        self.tokenizer = pipeline.tokenizer
        self.text_encoder = pipeline.text_encoder
        self.unet = pipeline.unet
        if is_xformers_available():
            self.unet.enable_xformers_memory_efficient_attention()

        if add_directional_text:
            text_chunks = []
            uncond_chunks = []
            for direction in self._VIEW_DIRECTIONS:
                prompt = f"{self.text}, {direction} view"
                negative = str(self.negative_text)
                if direction == 'back':
                    # Separate "face" from an existing negative prompt so it
                    # stays its own word instead of fusing with the last one.
                    negative = f"{negative}, face" if negative else "face"
                text_chunks.append(self.get_text_embeds([prompt], batch=1))
                uncond_chunks.append(
                    self.get_uncond_embeds([negative], batch=1))
            self.text_z = torch.cat(text_chunks)
            self.uncond_z = torch.cat(uncond_chunks)
        else:
            self.text_z = self.get_text_embeds([self.text], batch=self.batch)
            self.uncond_z = self.get_uncond_embeds(
                [self.negative_text], batch=self.batch)

        # The embeddings are cached above; drop the encoder from this module.
        del self.text_encoder

        self.scheduler = pipeline.scheduler
        self.num_train_timesteps = self.scheduler.config.num_train_timesteps
        # Convert the fractional ranges into integer timestep bounds.
        self.min_step_early = int(
            self.num_train_timesteps * early_time_step_range[0])
        self.max_step_early = int(
            self.num_train_timesteps * early_time_step_range[1])
        self.min_step_late = int(
            self.num_train_timesteps * late_time_step_range[0])
        self.max_step_late = int(
            self.num_train_timesteps * late_time_step_range[1])
        # Kept on the target device for convenience.
        self.alphas = self.scheduler.alphas_cumprod.to(self.device)

        print('[INFO] loaded stable diffusion!')

    def _encode_prompt(self, prompt, batch):
        """Tokenize ``prompt`` and return the text encoder's first output.

        The prompt is padded and truncated to the tokenizer's
        ``model_max_length``; encoding runs under ``torch.no_grad()``. When
        ``batch > 1`` the result is repeated ``batch`` times along dim 0.
        """
        tokens = self.tokenizer(
            prompt,
            padding='max_length',
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors='pt',
        )
        with torch.no_grad():
            embeddings = self.text_encoder(
                tokens.input_ids.to(self.device))[0]
        if batch > 1:
            embeddings = embeddings.repeat(batch, 1, 1)
        return embeddings

    def get_text_embeds(self, prompt, batch=1):
        """Return the text-encoder embeddings for the positive ``prompt``.

        Args:
            prompt: Prompt(s) in a form accepted by the tokenizer (the
                constructor passes a one-element list of strings).
            batch: Repeat count along dim 0; values <= 1 leave the
                embeddings as produced by the encoder.

        Returns:
            The embedding tensor on ``self.device``.
        """
        return self._encode_prompt(prompt, batch)

    def get_uncond_embeds(self, negative_prompt, batch):
        """Return the text-encoder embeddings for ``negative_prompt``.

        Same processing as ``get_text_embeds``, including truncation to the
        tokenizer's ``model_max_length`` so that an over-long negative prompt
        cannot produce more tokens than the encoder accepts.

        Args:
            negative_prompt: Negative prompt(s) in a form accepted by the
                tokenizer.
            batch: Repeat count along dim 0; values <= 1 leave the
                embeddings as produced by the encoder.

        Returns:
            The embedding tensor on ``self.device``.
        """
        return self._encode_prompt(negative_prompt, batch)

    def encode_imgs(self, imgs):
        """Encode images into scaled VAE latents.

        Args:
            imgs: Image batch, ``[B, 3, H, W]``. In ``'appearance_modeling'``
                mode the batch is first mapped with ``2 * imgs - 1``
                (presumably from [0, 1] to [-1, 1] -- confirm with callers);
                in every other mode it is passed to the VAE unchanged.

        Returns:
            A sample from the VAE posterior multiplied by the VAE's
            ``scaling_factor``.
        """
        if self.mode == 'appearance_modeling':
            imgs = 2 * imgs - 1
        posterior = self.vae.encode(imgs).latent_dist
        return posterior.sample() * self.vae.config.scaling_factor

    def decode_latents(self, latents):
        """Decode scaled latents back into images clamped to [0, 1].

        Args:
            latents: Latents as returned by ``encode_imgs`` (i.e. already
                multiplied by the VAE's ``scaling_factor``).

        Returns:
            The decoded images after ``x / 2 + 0.5`` and clamping to [0, 1].
        """
        latents = 1 / self.vae.config.scaling_factor * latents
        decoded = self.vae.decode(latents).sample
        return (decoded / 2 + 0.5).clamp(0, 1)