Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add simple variant mechanism #71

Merged
merged 6 commits into from
Aug 24, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions ldm/models/diffusion/ddim.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,17 @@


class DDIMSampler(object):
def __init__(self, model, schedule="linear", **kwargs):
def __init__(self, model, schedule="linear", device="cuda", **kwargs):
super().__init__()
self.model = model
self.ddpm_num_timesteps = model.num_timesteps
self.schedule = schedule
self.device = device

def register_buffer(self, name, attr):
if type(attr) == torch.Tensor:
if attr.device != torch.device("cuda"):
attr = attr.to(torch.device("cuda"))
if attr.device != torch.device(self.device):
attr = attr.to(torch.device(self.device))
setattr(self, name, attr)

def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
Expand Down
7 changes: 6 additions & 1 deletion ldm/models/diffusion/plms.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,18 @@


class PLMSSampler(object):
def __init__(self, model, schedule="linear", **kwargs):
def __init__(self, model, schedule="linear", device="cuda", **kwargs):
super().__init__()
self.model = model
self.ddpm_num_timesteps = model.num_timesteps
self.schedule = schedule
self.device = device

def register_buffer(self, name, attr):
    """Store ``attr`` on this sampler as the attribute ``name``.

    Tensors that are not already on ``self.device`` are moved there
    before being stored; every other kind of value is stored unchanged.

    Args:
        name: Attribute name to set on the sampler.
        attr: Value to store.
    """
    # isinstance (not an exact type match) so that Tensor subclasses are
    # moved to the configured device as well.
    if isinstance(attr, torch.Tensor):
        target = torch.device(self.device)
        if attr.device != target:
            attr = attr.to(target)

    setattr(self, name, attr)

def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
Expand Down
17 changes: 9 additions & 8 deletions ldm/simplet2i.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,6 @@
import os
from omegaconf import OmegaConf
from PIL import Image
import PIL
from tqdm import tqdm, trange
from itertools import islice
from einops import rearrange, repeat
Expand Down Expand Up @@ -158,7 +157,8 @@ def __init__(self,
@torch.no_grad()
def txt2img(self,prompt,outdir=None,batch_size=None,iterations=None,
steps=None,seed=None,grid=None,individual=None,width=None,height=None,
cfg_scale=None,ddim_eta=None,strength=None,embedding_path=None,init_img=None,skip_normalize=False):
cfg_scale=None,ddim_eta=None,strength=None,embedding_path=None,init_img=None,
skip_normalize=False,variants=None):
"""
Generate an image from the prompt, writing iteration images into the outdir
The output is a list of lists in the format: [[filename1,seed1], [filename2,seed2],...]
Expand Down Expand Up @@ -286,7 +286,8 @@ def txt2img(self,prompt,outdir=None,batch_size=None,iterations=None,
@torch.no_grad()
def img2img(self,prompt,outdir=None,init_img=None,batch_size=None,iterations=None,
steps=None,seed=None,grid=None,individual=None,width=None,height=None,
cfg_scale=None,ddim_eta=None,strength=None,embedding_path=None,skip_normalize=False):
cfg_scale=None,ddim_eta=None,strength=None,embedding_path=None,
skip_normalize=False,variants=None):
"""
Generate an image from the prompt and the initial image, writing iteration images into the outdir
The output is a list of lists in the format: [[filename1,seed1], [filename2,seed2],...]
Expand Down Expand Up @@ -324,7 +325,7 @@ def img2img(self,prompt,outdir=None,init_img=None,batch_size=None,iterations=Non
# PLMS sampler not supported yet, so ignore previous sampler
if self.sampler_name!='ddim':
print(f"sampler '{self.sampler_name}' is not yet supported. Using DDM sampler")
sampler = DDIMSampler(model)
sampler = DDIMSampler(model, device=self.device)
else:
sampler = self.sampler

Expand Down Expand Up @@ -462,9 +463,9 @@ def load_model(self):

msg = f'setting sampler to {self.sampler_name}'
if self.sampler_name=='plms':
self.sampler = PLMSSampler(self.model)
self.sampler = PLMSSampler(self.model, device=self.device)
elif self.sampler_name == 'ddim':
self.sampler = DDIMSampler(self.model)
self.sampler = DDIMSampler(self.model, device=self.device)
elif self.sampler_name == 'k_dpm_2_a':
self.sampler = KSampler(self.model,'dpm_2_ancestral')
elif self.sampler_name == 'k_dpm_2':
Expand All @@ -479,7 +480,7 @@ def load_model(self):
self.sampler = KSampler(self.model,'lms')
else:
msg = f'unsupported sampler {self.sampler_name}, defaulting to plms'
self.sampler = PLMSSampler(self.model)
self.sampler = PLMSSampler(self.model, device=self.device)

print(msg)

Expand All @@ -506,7 +507,7 @@ def _load_img(self,path):
w, h = image.size
print(f"loaded input image of size ({w}, {h}) from {path}")
w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
image = image.resize((w, h), resample=PIL.Image.LANCZOS)
image = image.resize((w, h), resample=Image.Resampling.LANCZOS)
image = np.array(image).astype(np.float32) / 255.0
image = image[None].transpose(0, 3, 1, 2)
image = torch.from_numpy(image)
Expand Down
27 changes: 26 additions & 1 deletion scripts/dream.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import atexit
import os
import sys
import copy
from PIL import Image,PngImagePlugin

# readline unavailable on windows systems
Expand Down Expand Up @@ -175,9 +176,32 @@ def main_loop(t2i,parser,log,infile):
print(e)
continue


allVariantResults = []
if opt.variants is not None:
print(f"Generating {opt.variants} variant(s)...")
newopt = copy.deepcopy(opt)
newopt.variants = None
for r in results:
newopt.init_img = r[0]
print(f"\t generating variant for {newopt.init_img}")
for j in range(0, opt.variants):
try:
variantResults = t2i.img2img(**vars(newopt))
allVariantResults.append([newopt,variantResults])
except AssertionError as e:
print(e)
continue
print(f"{opt.variants} Variants generated!")

print("Outputs:")
write_log_message(t2i,opt,results,log)

if allVariantResults:
print("Variant outputs:")
for vr in allVariantResults:
write_log_message(t2i,vr[0],vr[1],log)


print("goodbye!")

Expand Down Expand Up @@ -307,6 +331,7 @@ def create_cmd_parser():
parser.add_argument('-i','--individual',action='store_true',help="generate individual files (default)")
parser.add_argument('-I','--init_img',type=str,help="path to input image (supersedes width and height)")
parser.add_argument('-f','--strength',default=0.75,type=float,help="strength for noising/unnoising. 0.0 preserves image exactly, 1.0 replaces it completely")
parser.add_argument('-v','--variants',type=int,help="number of variants to generate of each image")
parser.add_argument('-x','--skip_normalize',action='store_true',help="skip subprompt weight normalization")
return parser

Expand All @@ -315,7 +340,7 @@ def setup_readline():
readline.set_completer(Completer(['cd','pwd',
'--steps','-s','--seed','-S','--iterations','-n','--batch_size','-b',
'--width','-W','--height','-H','--cfg_scale','-C','--grid','-g',
'--individual','-i','--init_img','-I','--strength','-f']).complete)
'--individual','-i','--init_img','-I','--strength','-f','-v','--variants']).complete)
readline.set_completer_delims(" ")
readline.parse_and_bind('tab: complete')
load_history()
Expand Down
2 changes: 1 addition & 1 deletion src/k-diffusion