Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added ESRGAN Example #47

Merged
merged 21 commits into from
Sep 7, 2021
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions tensorflow_gan/examples/esrgan/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
## ESRGAN

### How to run
1. Run the setup instructions in [tensorflow_gan/examples/README.md](https://github.com/tensorflow/gan/blob/master/tensorflow_gan/examples/README.md#steps-to-run-an-example)
2. Run:
```
python esrgan/train.py
```

nivedwho marked this conversation as resolved.
Show resolved Hide resolved
The Notebook files for training ESRGAN on Google Colaboratory can be found [here](https://github.com/tensorflow/gan/blob/master/tensorflow_gan/examples/esrgan/colab_notebooks/)

### Description
The ESRGAN model proposed in the paper [ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks (Wang Xintao et al.)](https://arxiv.org/abs/1809.00219) performs the task of image super-resolution which is the process of reconstructing high resolution (HR) image from a given low resolution (LR) image. Here we have trained the ESRGAN model on the DIV2K dataset and the model is evaluated using TF-GAN.

### Results
<img src="images/result1.png" title="Example 1" width="540" />
<img src="images/result2.png" title="Example 2" width="540" />
84 changes: 84 additions & 0 deletions tensorflow_gan/examples/esrgan/data_provider.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
# coding=utf-8
# Copyright 2021 The TensorFlow GAN Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow.python.data.experimental import AUTOTUNE


def random_flip(lr_img, hr_img):
""" Randomly flips LR and HR images for data augmentation."""
random = tf.random.uniform(shape=(), maxval=1)

return tf.cond(random<0.5,
lambda: (lr_img, hr_img),
lambda: (tf.image.flip_left_right(lr_img),
tf.image.flip_left_right(hr_img)))

def random_rotate(lr_img, hr_img):
""" Randomly rotates LR and HR images for data augmentation."""
random = tf.random.uniform(shape=(), maxval=4, dtype=tf.int32)
return tf.image.rot90(lr_img, random), tf.image.rot90(hr_img, random)


def get_div2k_data(hparams,
name='div2k/bicubic_x4',
mode='train',
shuffle=True,
repeat_count=None):
""" Downloads and loads DIV2K dataset.
Args:
hparams : A named tuple to store different parameters.
name : Name of the dataset to be loaded using tfds.
mode : Either 'train' or 'valid'.
shuffle : Whether to shuffle the images in the dataset.
repeat_count : Repetition of data during training.
Returns:
A tf.data.Dataset with pairs of LR image and HR image tensors.

Raises:
TypeError : If the data directory(data_dir) is not specified.
"""
split = 'train' if mode == 'train' else 'validation'

def scale(image, *args):
hr_size = hparams.hr_dimension
scale = hparams.scale

hr_image = image
hr_image = tf.image.resize(hr_image, [hr_size, hr_size])
lr_image = tf.image.resize(hr_image, [hr_size//scale, hr_size//scale], method='bicubic')

hr_image = tf.clip_by_value(hr_image, 0, 255)
lr_image = tf.clip_by_value(lr_image, 0, 255)

return lr_image, hr_image

dataset = (tfds.load(name=name,
split=split,
data_dir=hparams.data_dir,
as_supervised=True)
.map(scale, num_parallel_calls=AUTOTUNE)
.cache())

if shuffle:
dataset = dataset.shuffle(
buffer_size=10000, reshuffle_each_iteration=True)

dataset = dataset.batch(hparams.batch_size)
dataset = dataset.repeat(repeat_count)
dataset = dataset.prefetch(buffer_size=AUTOTUNE)

return dataset
50 changes: 50 additions & 0 deletions tensorflow_gan/examples/esrgan/data_provider_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
# coding=utf-8
# Copyright 2021 The TensorFlow GAN Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for tfgan.examples.esrgan.data_provider"""
import collections
from absl.testing import absltest
import tensorflow as tf
import data_provider


hparams = collections.namedtuple('hparams', ['hr_dimension',
'scale',
'batch_size',
'data_dir'])

class DataProviderTest(tf.test.TestCase, absltest.TestCase):
def setUp(self):
super(DataProviderTest, self).setUp()
self.hparams = hparams(256, 4, 32, '/content/')
self.dataset = data_provider.get_div2k_data(self.hparams)
self.mock_lr = tf.random.normal([32, 64, 64, 3])
self.mock_hr = tf.random.normal([32, 256, 256, 3])

def test_dataset(self):
self.assertIsInstance(self.dataset, tf.data.Dataset)
with self.cached_session() as sess:
lr_image, hr_image = next(iter(self.dataset))
sess.run(tf.compat.v1.global_variables_initializer())

self.assertEqual(type(self.mock_lr), type(lr_image))
self.assertEqual(self.mock_lr.shape, lr_image.shape)

self.assertEqual(type(self.mock_hr), type(hr_image))
self.assertEqual(self.mock_hr.shape, hr_image.shape)


if __name__ == '__main__':
tf.test.main()
57 changes: 57 additions & 0 deletions tensorflow_gan/examples/esrgan/eval.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
# coding=utf-8
# Copyright 2021 The TensorFlow GAN Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from absl import flags, logging, app
import tensorflow as tf
import eval_lib
import data_provider


flags.DEFINE_integer('batch_size', 16,
'The number of images in each batch.')
flags.DEFINE_integer('hr_dimension', 128,
'Dimension of a HR image.')
flags.DEFINE_integer('scale', 4,
'Factor by which LR images are downscaled.')
flags.DEFINE_string('model_dir', '/content/',
'Directory where the trained models are stored.')
flags.DEFINE_string('data_dir', '/content/datasets',
'Directory where dataset is stored.')
flags.DEFINE_integer('num_steps', 100,
'The number of steps for evaluation.')
flags.DEFINE_integer('num_inception_images', 16,
'The number of images passed for evaluation at each step.')
flags.DEFINE_string('image_dir', '/content/results',
'Directory to save generated images during evaluation.')
flags.DEFINE_boolean('eval_real_images', False,
'Whether Phase 1 training is done or not')

FLAGS = flags.FLAGS

def main(_):
nivedwho marked this conversation as resolved.
Show resolved Hide resolved
hparams = eval_lib.HParams(FLAGS.batch_size, FLAGS.hr_dimension,
FLAGS.scale, FLAGS.model_dir,
FLAGS.data_dir,FLAGS.num_steps,
FLAGS.num_inception_images,FLAGS.image_dir,
FLAGS.eval_real_images)

generator = tf.keras.models.load_model(FLAGS.model_dir +
'Phase_2/interpolated_generator')
data = data_provider.get_div2k_data(hparams, mode='valid')
eval_lib.evaluate(hparams, generator, data)

if __name__ == '__main__':
logging.set_verbosity(logging.INFO)
app.run(main)
77 changes: 77 additions & 0 deletions tensorflow_gan/examples/esrgan/eval_lib.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
# coding=utf-8
# Copyright 2021 The TensorFlow GAN Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import collections
import tensorflow as tf
from absl import logging
import utils

hparams = collections.namedtuple('hparams', [
'batch_size', 'hr_dimension',
'scale', 'model_dir',
'data_dir', 'num_steps',
'num_inception_images', 'image_dir',
'eval_real_images'])

def evaluate(hparams, generator, data):
""" Runs an evaluation loop and calculates the mean FID,
Inception and PSNR scores observed on the validation dataset.

Args:
hparams: Parameters for evaluation.
generator : The trained generator network.
data : Validation DIV2K dataset.
"""
fid_metric = tf.keras.metrics.Mean()
inc_metric = tf.keras.metrics.Mean()
psnr_metric = tf.keras.metrics.Mean()
step = 0

for lr, hr in data.take(hparams.num_steps):
step += 1
# Generate fake images for evaluating the model
gen = generator(lr)

if step%hparams.num_steps//10 == 0:
utils.visualize_results(lr,
gen,
hr,
image_dir=hparams.image_dir,
step=step)

# Compute Frechet Inception Distance.
fid_score = utils.get_frechet_inception_distance(
hr, gen,
hparams.batch_size,
hparams.num_inception_images)
fid_metric(fid_score)

# Compute Inception Scores.
if hparams.eval_real_images:
inc_score = utils.get_inception_scores(hr,
hparams.batch_size,
hparams.num_inception_images)
else:
inc_score = utils.get_inception_scores(gen,
hparams.batch_size,
hparams.num_inception_images)
inc_metric(inc_score)

# Compute PSNR values.
psnr = utils.get_psnr(hr, gen)
psnr_metric(psnr)

logging.info('FID Score :{}\tInception Score :{}\tPSNR value :{}'.format(
fid_metric.result(), inc_metric.result(), psnr_metric.result()))
52 changes: 52 additions & 0 deletions tensorflow_gan/examples/esrgan/eval_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
# coding=utf-8
# Copyright 2021 The TensorFlow GAN Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for tfgan.examples.esrgan.eval"""

import collections

import tensorflow as tf
import eval_lib
import networks

hparams = collections.namedtuple('hparams', [
'num_steps', 'image_dir', 'batch_size', 'num_inception_images',
'eval_real_images', 'hr_dimension', 'scale', 'trunk_size'])

class EvalTest(tf.test.TestCase):
def setUp(self):
self.hparams = hparams(1, '/content/',
2, 2,
True, 256,
4, 11)

d = tf.data.Dataset.from_tensor_slices(tf.random.normal([2, 256, 256, 3]))
def lr(hr):
lr = tf.image.resize(hr, [64, 64], method='bicubic')
return lr, hr

d = d.map(lr)
d = d.batch(2)
self.mock_dataset = d
self.generator = networks.generator_network(self.hparams)

def test_eval(self):
self.assertIsNone(eval_lib.evaluate(self.hparams,
self.generator,
self.mock_dataset))

if __name__ == '__main__':
tf.test.main()

Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading