Skip to content

Commit

Permalink
feat(ml): control over image generation with diffusion in-painting m…
Browse files Browse the repository at this point in the history
…odel
  • Loading branch information
beniz committed Nov 8, 2022
1 parent 8080036 commit 0a5ed86
Showing 1 changed file with 56 additions and 44 deletions.
100 changes: 56 additions & 44 deletions scripts/gen_single_image_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,11 @@
from data.online_creation import fill_mask_with_random, fill_mask_with_color


def load_model(modelpath, model_in_file, device):
def load_model(modelpath, model_in_file, device, sampling_steps):
train_json_path = modelpath + "/train_config.json"
with open(train_json_path, "r") as jsonf:
train_json = json.load(jsonf)

opt = TrainOptions().parse_json(train_json)
if opt.model_multimodal:
opt.model_input_nc += opt.train_mm_nz
Expand All @@ -27,6 +28,11 @@ def load_model(modelpath, model_in_file, device):
model.eval()
model.load_state_dict(torch.load(modelpath + "/" + model_in_file))

# sampling steps
if sampling_steps > 0:
model.denoise_fn.beta_schedule["test"]["n_timestep"] = sampling_steps
model.denoise_fn.set_new_noise_schedule("test")

model = model.to(device)
return model, opt

Expand All @@ -36,41 +42,79 @@ def load_model(modelpath, model_in_file, device):
"--model-in-file", help="file path to generator model (.pth file)", required=True
)

parser.add_argument("--img-size", default=256, type=int, help="square image size")
parser.add_argument("--img-width", default=-1, type=int, help="image width")
parser.add_argument("--img-height", default=-1, type=int, help="image height")

parser.add_argument("--img-in", help="image to transform", required=True)
parser.add_argument(
"--mask-in", help="mask used for image transformation", required=True
"--mask-in", help="mask used for image transformation", required=False
)
parser.add_argument("--bbox-in", help="bbox file used for masking")
parser.add_argument("--img-out", help="transformed image", required=True)
parser.add_argument(
"--sampling-steps", default=-1, type=int, help="number of sampling steps"
)
parser.add_argument("--cpu", action="store_true", help="whether to use CPU")
parser.add_argument("--gpuid", type=int, default=0, help="which GPU to use")
parser.add_argument(
"--seed", type=int, default=-1, help="random seed for reproducibility"
)
args = parser.parse_args()

# seed
if args.seed >= 0:
torch.manual_seed(args.seed)

# loading model
modelpath = args.model_in_file.replace(os.path.basename(args.model_in_file), "")
print("modelpath=", modelpath)

if not args.cpu:
device = torch.device("cuda:" + str(args.gpuid))
model, opt = load_model(modelpath, os.path.basename(args.model_in_file), device)
model, opt = load_model(
modelpath, os.path.basename(args.model_in_file), device, args.sampling_steps
)

# reading image
img = cv2.imread(args.img_in)
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
mask = cv2.imread(args.mask_in, 0)

# preprocessing
# Build the inpainting mask, either from a mask image (--mask-in) or from a
# bounding-box file (--bbox-in). When both are given, the bbox mask replaces the
# mask image because the mask is rebuilt from zeros below.
mask = None
if args.mask_in:
    # flag 0 makes OpenCV load the file as a single-channel grayscale image
    mask = cv2.imread(args.mask_in, 0)
    if mask is None:
        # cv2.imread signals an unreadable/missing file by returning None
        raise FileNotFoundError("could not read mask image: " + args.mask_in)

bboxes = []
if args.bbox_in:
    mask = np.zeros(img.shape[:2], dtype=np.uint8)
    with open(args.bbox_in, "r") as bboxf:
        for line in bboxf:
            elts = line.split()
            if not elts:
                # tolerate blank lines (e.g. a trailing newline at end of file)
                continue
            # the first field of each line is not used; the next four are read
            # as xmin, ymin, xmax, ymax
            xmin, ymin, xmax, ymax = (int(v) for v in elts[1:5])
            bboxes.append([xmin, ymin, xmax, ymax])
    for xmin, ymin, xmax, ymax in bboxes:
        # Scalar assignment broadcasts over the slice, so a box that extends past
        # the image border is clipped by the slicing instead of raising a shape
        # mismatch.
        mask[ymin:ymax, xmin:xmax] = 1

if mask is None:
    # --mask-in is optional, so make sure at least one mask source was given
    raise ValueError("either --mask-in or --bbox-in must be provided")

# Optional resize of the image and its mask. --img-width and --img-height both
# default to -1 ("not set"); -1 is truthy, so each value has to be compared
# against 0 explicitly instead of relying on truthiness.
if args.img_width > 0 or args.img_height > 0:
    # a dimension left unset keeps the size of the loaded image
    target_width = args.img_width if args.img_width > 0 else img.shape[1]
    target_height = args.img_height if args.img_height > 0 else img.shape[0]
    # cv2.resize expects the target size as (width, height)
    img = cv2.resize(img, (target_width, target_height))
    mask = cv2.resize(mask, (target_width, target_height))

# preprocessing to torch
totensor = transforms.ToTensor()
resize = transforms.Resize(args.img_size)
tranlist = [
totensor,
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
resize,
# resize,
]
# if args.img_size > 0:
# resize = transforms.Resize(args.img_size)
# tranlist.append(resize)

tran = transforms.Compose(tranlist)
img_tensor = tran(img).clone().detach()

mask = torch.from_numpy(np.array(mask, dtype=np.int64)).unsqueeze(0)
mask = resize(mask).clone().detach()
# mask = resize(mask).clone().detach()

if not args.cpu:
img_tensor = img_tensor.to(device).clone().detach()
Expand Down Expand Up @@ -103,11 +147,11 @@ def load_model(modelpath, model_in_file, device):
)


print("outtensor", out_tensor.shape, "visu", visu.shape)
# print("outtensor", out_tensor.shape, "visu", visu.shape)

temp = img_tensor - out_tensor
print(temp.mean(), temp.min(), temp.max())
print(visu.shape)
# print(temp.mean(), temp.min(), temp.max())
# print(visu.shape)

# out_tensor = visu[-1:]

Expand All @@ -116,44 +160,12 @@ def load_model(modelpath, model_in_file, device):
img_np = img_tensor.detach().data.cpu().float().numpy()[0]
cond_image = cond_image.detach().data.cpu().float().numpy()[0]
# cond_image = torch.randn_like(cond_image)
visu = visu.detach().data.cpu().float().numpy()
visu1 = visu[1]
visu2 = visu[2]
visu0 = visu[0]

temp = out_img - img_np
print("np", temp.mean(), temp.min(), temp.max())

out_img = (np.transpose(out_img, (1, 2, 0)) + 1) / 2.0 * 255.0
img_np = (np.transpose(img_np, (1, 2, 0)) + 1) / 2.0 * 255.0
cond_image = (np.transpose(cond_image, (1, 2, 0)) + 1) / 2.0 * 255.0
visu0 = (np.transpose(visu0, (1, 2, 0)) + 1) / 2.0 * 255.0
visu1 = (np.transpose(visu1, (1, 2, 0)) + 1) / 2.0 * 255.0
visu2 = (np.transpose(visu2, (1, 2, 0)) + 1) / 2.0 * 255.0
print(out_img)
print(img_np)

temp = out_img - img_np
print("np", temp.mean(), temp.min(), temp.max())

out_img = cv2.cvtColor(out_img, cv2.COLOR_RGB2BGR)
cv2.imwrite(args.img_out, out_img)

img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
cv2.imwrite("/data1/pnsuau/checkpoints/test_palette_4/img_np.jpg", img_np)

cond_image = cv2.cvtColor(cond_image, cv2.COLOR_RGB2BGR)
cv2.imwrite("/data1/pnsuau/checkpoints/test_palette_4/cond_image.jpg", cond_image)


visu0 = cv2.cvtColor(visu0, cv2.COLOR_RGB2BGR)
cv2.imwrite("/data1/pnsuau/checkpoints/test_palette_4/visu0.jpg", visu0)

visu1 = cv2.cvtColor(visu1, cv2.COLOR_RGB2BGR)
cv2.imwrite("/data1/pnsuau/checkpoints/test_palette_4/visu1.jpg", visu1)

visu2 = cv2.cvtColor(visu2, cv2.COLOR_RGB2BGR)
cv2.imwrite("/data1/pnsuau/checkpoints/test_palette_4/visu2.jpg", visu2)


print("Successfully generated image ", args.img_out)

0 comments on commit 0a5ed86

Please sign in to comment.