Skip to content

Commit

Permalink
rt_gene_inpainting cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
Tobias-Fischer committed Aug 5, 2020
1 parent 7c616d0 commit ece1e8d
Show file tree
Hide file tree
Showing 9 changed files with 146 additions and 446 deletions.
42 changes: 19 additions & 23 deletions rt_gene_inpainting/GAN_train.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,13 @@
from __future__ import print_function, division, absolute_import

from my_utils import *
from models import LSGAN_Model, set_trainability
import numpy as np
import tensorflow as tf
from utils import *
import os
from glob import glob

from tensorflow.keras.models import load_model
import numpy as np
from tensorflow.keras.callbacks import TensorBoard

from datetime import datetime

from tqdm import tqdm, tnrange, tqdm_notebook
from tqdm.auto import tqdm
from utils import PRL_data_image_load, write_log, GAN_plot_images


class GAN_train(object):
Expand Down Expand Up @@ -87,30 +81,30 @@ def train(self, num_epoch=2000, batch_size=256, save_interval=0):

# Initial Update Discriminator
set_trainability(self.discriminator, True)
d_loss_real = self.discriminator_cost.train_on_batch(images_train, np.ones([batch_size,1]))
d_loss_fake = self.discriminator_cost.train_on_batch(images_fake, np.zeros([batch_size,1]))
d_loss_real = self.discriminator_cost.train_on_batch(images_train, np.ones([batch_size, 1]))
d_loss_fake = self.discriminator_cost.train_on_batch(images_fake, np.zeros([batch_size, 1]))
d_loss = d_loss_real + d_loss_fake

# TRAINING STEPS ------------------------------------------------------
print('========= Main LSGAN Training ==========')
num_batch = self.num_total_data // batch_size

for e in xrange(num_epoch):
for e in range(num_epoch):
shuffled_sample_idx = np.random.permutation(self.num_total_data)

for b in tqdm(xrange(num_batch)):
batch_sample_idx = shuffled_sample_idx[b*batch_size:(b+1)*batch_size];
images_train = PRL_data_image_load(self.data, sample_idx=sample_idx)
for b in tqdm(range(num_batch)):
batch_sample_idx = shuffled_sample_idx[b*batch_size:(b+1)*batch_size]

images_train = PRL_data_image_load(self.data, sample_idx=batch_sample_idx)

# noise = np.random.normal(0.0, 1.0, size=[batch_size, self.noise_dim])
noise = np.random.uniform(-1.0, 1.0, size=[batch_size, self.noise_dim])
images_fake = self.generator.predict(noise)

# Update Discriminator
set_trainability(self.discriminator, True)
d_loss_real = self.discriminator_cost.train_on_batch(images_train, np.ones([batch_size,1]))
d_loss_fake = self.discriminator_cost.train_on_batch(images_fake, np.zeros([batch_size,1]))
d_loss_real = self.discriminator_cost.train_on_batch(images_train, np.ones([batch_size, 1]))
d_loss_fake = self.discriminator_cost.train_on_batch(images_fake, np.zeros([batch_size, 1]))
d_loss = d_loss_real + d_loss_fake

# Update Generator
Expand All @@ -136,13 +130,15 @@ def train(self, num_epoch=2000, batch_size=256, save_interval=0):
log_mesg = "%s: [D loss: %f, acc: %f]" % (log_mesg, d_loss[0], d_loss[1])
log_mesg = "%s [A loss: %f, acc: %f]" % (log_mesg, a_loss[0], a_loss[1])
print(log_mesg)

GAN_plot_images(generator = self.generator, x_train=self.x_train, dataset = self.dataset, save2file=True, samples=sample_noise_input.shape[0], noise=sample_noise_input, step=(e + 1), folder_path=self.dataset_path_GAN_samples)
GAN_plot_images(generator = self.generator, x_train=self.x_train, dataset = self.dataset, save2file=False, samples=sample_noise_input.shape[0], noise=sample_noise_input, step=(e + 1))

GAN_plot_images(generator=self.generator, x_train=self.x_train, dataset=self.dataset,
save2file=True, samples=sample_noise_input.shape[0], noise=sample_noise_input,
step=(e + 1), folder_path=self.dataset_path_GAN_samples)
GAN_plot_images(generator=self.generator, x_train=self.x_train, dataset=self.dataset,
save2file=False, samples=sample_noise_input.shape[0], noise=sample_noise_input,
step=(e + 1))

# Save trained models
self.adversarial_cost.save(self.dataset_path_GAN_model+"/GAN_"+str(e+1)+"_"+self.dataset+"_forganECCV_adversarial_model_uniform.h5")
self.discriminator.save(self.dataset_path_GAN_model+"/GAN_"+str(e+1)+"_"+self.dataset+"_forganECCV_discriminator_uniform.h5")
self.generator.save(self.dataset_path_GAN_model+"/GAN_"+str(e+1)+"_"+self.dataset+"_forganECCV_generator_uniform.h5")


7 changes: 2 additions & 5 deletions rt_gene_inpainting/GAN_train_run.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
"outputs": [],
"source": [
"from GAN_train import GAN_train\n",
"from my_utils import ElapsedTimer\n",
"\n",
"import tensorflow as tf\n",
"\n",
Expand All @@ -21,9 +20,7 @@
" subject = 's000'\n",
" gan_train = GAN_train(dataset_folder_path, subject)\n",
"\n",
" timer = ElapsedTimer()\n",
" gan_train.train(num_epoch=100, batch_size=96, save_interval=1)\n",
" timer.elapsed_time()"
" gan_train.train(num_epoch=100, batch_size=96, save_interval=1)"
]
},
{
Expand Down Expand Up @@ -55,4 +52,4 @@
},
"nbformat": 4,
"nbformat_minor": 2
}
}
65 changes: 27 additions & 38 deletions rt_gene_inpainting/GlassesCompletion.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,12 @@
from __future__ import print_function, division, absolute_import

from my_utils import *
from models import LSGAN_Model, Completion_Model, set_trainability
import numpy as np
import tensorflow as tf
from models import LSGAN_Model, Completion_Model
from utils import *
import os
from glob import glob
from tqdm import tqdm, tnrange
from tqdm import tqdm

from tensorflow.keras.callbacks import TensorBoard
from tensorflow.keras import backend as K

import matplotlib.pyplot as plt
from datetime import datetime

import external.poissonblending as blending


Expand Down Expand Up @@ -72,13 +64,12 @@ def __init__(self, dataset_common_folder_path, dataset):

print('Done Loading Pre-trained Network!')


def image_completion_random_search(self, nIter=1000, GPU_ID="0"):
filename_total_face = sorted(glob(os.path.join(self.path_images, 'face_*.png')))

self.num_total_data = len(filename_total_face)
num_total_data = len(filename_total_face)

print(self.num_total_data)
print(num_total_data)

print('=======================================================')

Expand Down Expand Up @@ -115,7 +106,7 @@ def image_completion_random_search(self, nIter=1000, GPU_ID="0"):

print(self.path_completion)

for img_idx in tqdm(range(0, self.num_total_data)):
for img_idx in tqdm(range(0, num_total_data)):
filename_face = filename_total_face[img_idx]
filename_index = filename_face[-14:-8]
filename_mask = self.folder_path_images + '/original/mask/mask_' + filename_index + '_overlay.png'
Expand All @@ -124,90 +115,88 @@ def image_completion_random_search(self, nIter=1000, GPU_ID="0"):
if os.path.isfile(filename_out):
continue

data_face = imread_PRL(filename_face, is_grayscale = False)
data_face = imread_PRL(filename_face, is_grayscale=False)
image_face = np.array(data_face).astype(np.float32)

data_mask = imread_PRL(filename_mask, is_grayscale = True)
data_mask = imread_PRL(filename_mask, is_grayscale=True)
image_mask = np.array(data_mask).astype(np.float32)

# Sample index
sample_num = 1
sample_noise_input = np.random.uniform(-1.0, 1.0, size=[sample_num, self.noise_dim])
self.sample_num = sample_num
# sample_noise_input = np.random.uniform(-1.0, 1.0, size=[sample_num, self.noise_dim])

# mask generation
mask = self.mask_PRL_Glasses(image_mask)

masked_images = np.multiply(image_face, mask)
# masked_images = np.multiply(image_face, mask)

y = np.ones([sample_num, 1])
# y = np.ones([sample_num, 1])
zhats = np.random.uniform(-1.0, 1.0, size=[sample_num, self.noise_dim])

loss_buf = 0
# loss_buf = 0

l_buf = 10000000
zhats_buf = zhats
final_iter = 0
# final_iter = 0

for j in range(nIter):
zhats_search = np.random.uniform(-1.0, 1.0, size=[sample_num, self.noise_dim])
G_imgs = self.generator.predict(zhats_search)
G_imgs = np.squeeze(G_imgs)
g, l, lc, lp = sess.run([gradients, loss, loss_contextual, loss_perceptual], feed_dict={complete_loss_model.input:zhats_search, mask_tensor:mask, images_tensor: image_face, G_images_tensor: G_imgs})
g, l, lc, lp = sess.run([gradients, loss, loss_contextual, loss_perceptual], feed_dict={complete_loss_model.input: zhats_search, mask_tensor: mask, images_tensor: image_face, G_images_tensor: G_imgs})

if np.sum(l) < l_buf:
l_buf = np.sum(l)
zhats_buf = zhats_search
final_iter = j
# final_iter = j

zhats = zhats_buf
zhats = zhats_buf
G_imgs = self.generator.predict(zhats)
G_imgs = np.squeeze(G_imgs)

#--------------------------------------------------------------
# --------------------------------------------------------------
# Generate completed images
inv_masked_hat_images = np.multiply(G_imgs, 1.0-mask)
completed = masked_images + inv_masked_hat_images
# inv_masked_hat_images = np.multiply(G_imgs, 1.0-mask)
# completed = masked_images + inv_masked_hat_images

filename = self.path_completion+'/hats/' + filename_index + '.png'
scipy.misc.imsave(filename, (G_imgs + 1) / 2)

# Poisson Blending
image_out = self.iminvtransform(G_imgs)
image_in = self.iminvtransform(image_face)
image_in = self.iminvtransform(image_face)

try:
image_out = self.poissonblending(image_in, image_out, mask)
filename = self.path_completion+'/blended/' + filename_index + '.png'
filename = self.path_completion+'/blended/' + filename_index + '.png'
scipy.misc.imsave(filename, image_out)
except:
print("Error occurred while blending: " + str(filename_index))
pass

sess.close()


def mask_PRL_Glasses(self, mask_images):
mask = np.ones(self.image_shape)

for ir in range(self.img_rows):
for ic in range(self.img_cols):
# if mask_images[ir,ic] >= (127.5/127.5-1):
if mask_images[ir,ic] > (0/127.5-1):
mask[ir,ic,:] = 0
if mask_images[ir, ic] > (0/127.5-1):
mask[ir, ic, :] = 0

return mask

def poissonblending(self, img1, img2, mask):
@staticmethod
def poissonblending(img1, img2, mask):
"""Helper: interface to external poisson blending"""
return blending.blend(img1, img2, 1 - mask)
return blending.blend(img1, img2, 1 - mask)

def iminvtransform(self, img):
@staticmethod
def iminvtransform(img):
"""Helper: Rescale pixel value ranges to 0 and 1"""
return (np.array(img) + 1.0) / 2.0


def loss_LSGAN(y_true, y_pred):
    """Return half the mean squared difference between predictions and targets.

    The squared difference ``(y_pred - y_true) ** 2`` is averaged over the
    last axis with the Keras backend (``K``) and the result is divided by 2.

    Args:
        y_true: Target values.
        y_pred: Predicted values.

    Returns:
        The halved mean squared difference, reduced over the last axis.
    """
    residual = y_pred - y_true
    squared_residual = K.square(residual)
    return K.mean(squared_residual, axis=-1) / 2


1 change: 0 additions & 1 deletion rt_gene_inpainting/GlassesCompletion_run.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
#!/usr/bin/env python

from GlassesCompletion import GlassesCompletion
from my_utils import ElapsedTimer

import tensorflow as tf

Expand Down
4 changes: 3 additions & 1 deletion rt_gene_inpainting/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ This work was supported in part by the Samsung Global Research Outreach program,
More information can be found on the Personal Robotic Lab's website: <https://www.imperial.ac.uk/personal-robotics/software/>.

## Requirements
`pip install tensorflow-gpu keras numpy scipy tqdm matplotlib pyamg Pillow`
- pip: `pip install tensorflow-gpu keras numpy "scipy<=1.2.1" tqdm matplotlib pyamg`
- conda: `conda install tensorflow-gpu keras numpy "scipy<=1.2.1" tqdm matplotlib pyamg`

## Inpainting source code
This code was used to inpaint the region covered by the eyetracking glasses. There are two parts:
Expand All @@ -41,6 +42,7 @@ In `GAN_train_run.ipynb` and `GlassesCompletion_run.py` the `dataset_folder_path

## List of libraries
- [./external/poissonblending.py](./external/poissonblending.py): [MIT License](https://opensource.org/licenses/MIT); [Link to GitHub](https://github.com/parosky/poissonblending)
- Some code taken from [DC-GAN](https://github.com/Newmu/dcgan_code): [MIT License](https://github.com/Newmu/dcgan_code/blob/master/LICENSE); [Link to GitHub](https://github.com/Newmu/dcgan_code)
- Tensorflow; [Apache License 2.0](https://www.apache.org/licenses/LICENSE-2.0), [Link to website](http://tensorflow.org/)
- Keras; [MIT License](https://opensource.org/licenses/MIT), [Link to website](https://keras.io)

Loading

0 comments on commit ece1e8d

Please sign in to comment.