Skip to content

Commit

Permalink
added inference + made prelu channel wise + added checkpointing
Browse files Browse the repository at this point in the history
  • Loading branch information
AndreiMoraru123 committed Jul 28, 2023
1 parent 899f35a commit d16de1a
Show file tree
Hide file tree
Showing 4 changed files with 138 additions and 9 deletions.
6 changes: 3 additions & 3 deletions model.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def __init__(

# PReLU and LeakyReLU have configurable parameters, so we can't just pass the strings to Keras
if activation == 'prelu':
self.activation_layer = layers.PReLU()
self.activation_layer = layers.PReLU(shared_axes=[1, 2])
elif activation == 'leakyrelu':
self.activation_layer = layers.LeakyReLU(0.2)
elif activation == 'tanh':
Expand Down Expand Up @@ -87,7 +87,7 @@ def __init__(self, kernel_size: int = 3, n_channels: int = 64, scaling_factor: i
self.conv = layers.Conv2D(filters=n_channels * (scaling_factor ** 2),
kernel_size=kernel_size, padding='same')
self.scaling_factor = scaling_factor
self.prelu = layers.PReLU()
self.prelu = layers.PReLU(shared_axes=[1, 2])

def call(self, inputs: tf.Tensor) -> tf.Tensor:
"""
Expand Down Expand Up @@ -231,7 +231,7 @@ def initialize_with_srresnet(self, srresnet_checkpoint: Model):
:param srresnet_checkpoint: checkpoint filepath
"""
self.net = tf.keras.models.load_model(srresnet_checkpoint)
self.net = tf.saved_model.load(srresnet_checkpoint)

def call(self, low_res_images: tf.Tensor, training: bool = False) -> tf.Tensor:
"""
Expand Down
105 changes: 105 additions & 0 deletions resolve.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
# third-party imports
import tensorflow as tf
from PIL import Image, ImageDraw, ImageFont

# module imports
from transforms import ImageTransform


# Inference configuration shared by everything in this module.
crop_size = 96 # crop size of target HR images
scaling_factor = 4 # the input LR images will be down-sampled from the target HR images by this factor

# Load the exported SRResNet SavedModel directory and grab its default serving signature.
# The signature is called with an image batch and returns a dict of named outputs.
resnet = tf.saved_model.load("SuperResolutionResNet_9999")
resnet_inference = resnet.signatures["serving_default"]

# Same for the SRGAN generator SavedModel.
generator = tf.saved_model.load("Generator_9999")
generator_inference = generator.signatures["serving_default"]

# Need this instance for conversions
# (only its `convert_image` method is used in this module, to move between PIL images and tensors;
# NOTE(review): the split/crop arguments are presumably irrelevant for conversion - confirm in transforms.py)
transform = ImageTransform(split="train",
crop_size=crop_size,
lr_img_type='imagenet-norm',
hr_img_type='[-1, 1]',
scaling_factor=scaling_factor)


def _run_super_resolution(inference_fn, lr_batch):
    """
    Runs one exported serving signature on a low-resolution batch and converts the result to a PIL image.

    :param inference_fn: serving signature of a loaded SavedModel (e.g. `resnet_inference`)
    :param lr_batch: low-resolution image batch of size 1, already converted to the model's input range
    :return: the super-resolved image as a PIL image
    """
    # The serving signature returns a dict; the single output tensor is stored under 'output_0'
    sr_tensor = tf.squeeze(inference_fn(lr_batch)['output_0'])
    return transform.convert_image(sr_tensor, source='[-1, 1]', target='pil')


def _paste_with_label(grid_img, draw, font, image, left: int, top: int, label: str, gap: int = 5):
    """
    Pastes an image onto the grid and draws its caption horizontally centred just above it.

    :param grid_img: the PIL canvas to paste onto
    :param draw: the ImageDraw drawer bound to `grid_img`
    :param font: the font used for the caption
    :param image: the PIL image to paste
    :param left: x coordinate of the image's top-left corner on the grid
    :param top: y coordinate of the image's top-left corner on the grid
    :param label: caption text
    :param gap: vertical space in pixels between the caption and the image
    """
    grid_img.paste(image, (left, top))

    # getbbox returns (left, top, right, bottom) relative to the drawing origin, NOT (width, height),
    # so the width has to be derived and the bottom edge is what must clear the image
    bbox_left, _, bbox_right, bbox_bottom = font.getbbox(label)
    label_width = bbox_right - bbox_left

    draw.text(xy=(left + image.width / 2 - label_width / 2, top - bbox_bottom - gap),
              text=label, font=font, fill='black')


def super_resolve(img: str, halve: bool = False):
    """
    Visualizes the super-resolved images from the SRResNet and SRGAN for comparison with the bicubic up-sampled image
    and the original high-resolution (HR) image, as done in the paper.

    The resulting 2x2 grid is saved next to the input as "<input name without extension>_resolved.png".

    :param img: filepath of the HR image
    :param halve: halve each dimension of the HR image to make sure it's not greater than the dimensions of your screen?
                  For instance, for a 2160p HR image, the LR image will be of 540p (1080p/4) resolution. On a 1080p
                  screen, you will therefore be looking at a comparison between a 540p LR image and a 1080p SR/HR image
                  because your 1080p screen can only display the 2160p SR/HR image at a down-sampled 1080p. This is only
                  an APPARENT rescaling of 2x.
                  If you want to reduce HR resolution by a different extent, modify accordingly.
    """
    # Function-scope import so this block does not depend on the file's import section
    import os

    # Load image; `convert` returns an independent copy, so the file handle can be released right away
    with Image.open(img, mode="r") as source_img:
        hr_img = source_img.convert('RGB')

    if halve:
        hr_img = hr_img.resize((int(hr_img.width / 2), int(hr_img.height / 2)),
                               Image.LANCZOS)

    # Down-sample to obtain the low-res version, using the same factor the models were configured with
    lr_img = hr_img.resize((int(hr_img.width / scaling_factor), int(hr_img.height / scaling_factor)),
                           Image.BICUBIC)

    # Bicubic Upsampling
    bicubic_img = lr_img.resize((hr_img.width, hr_img.height), Image.BICUBIC)

    lr_batch = tf.expand_dims(transform.convert_image(lr_img, source='pil', target='imagenet-norm'), axis=0)

    # Super-resolution (SR) with SRResNet
    sr_img_srresnet = _run_super_resolution(resnet_inference, lr_batch)

    # Super-resolution (SR) with SRGAN (must use the generator, not the SRResNet signature)
    sr_img_srgan = _run_super_resolution(generator_inference, lr_batch)

    # Create grid
    margin = 40
    grid_img = Image.new('RGB', (2 * hr_img.width + 3 * margin, 2 * hr_img.height + 3 * margin), (255, 255, 255))

    # Drawer and font
    draw = ImageDraw.Draw(grid_img)
    font = ImageFont.load_default()

    # Grid coordinates: two columns and two rows separated by `margin`
    left_col = margin
    right_col = 2 * margin + bicubic_img.width
    top_row = margin
    bottom_row = 2 * margin + sr_img_srresnet.height

    _paste_with_label(grid_img, draw, font, bicubic_img, left_col, top_row, "Bicubic")
    _paste_with_label(grid_img, draw, font, sr_img_srresnet, right_col, top_row, "SRResNet")
    _paste_with_label(grid_img, draw, font, sr_img_srgan, left_col, bottom_row, "SRGAN")
    _paste_with_label(grid_img, draw, font, hr_img, right_col, bottom_row, "Original HR")

    # Save image; splitext handles extensions of any length (".jpg", ".png", ".jpeg", ...)
    grid_img.save(os.path.splitext(img)[0] + "_resolved" + ".png")


# Script entry point: build the comparison grid for the sample image "bird.jpeg" at full resolution.
if __name__ == '__main__':
    super_resolve("bird.jpeg", halve=False)
11 changes: 5 additions & 6 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
small_kernel_size = 3 # kernel size of the first and last convolutions which transform the inputs and outputs
n_channels = 64 # number of channels in-between, input and output channels for residual & subpixel conv blocks
n_blocks = 16 # number of residual blocks
srresnet_checkpoint = "srresnet" # filepath of the trained SRResNet checkpoint used for initialization
srresnet_checkpoint = "SuperResolutionResNet_9999" # trained SRResNet checkpoint used for initialization

# Discriminator parameters
kernel_size_d = 3 # kernel size in all convolutional blocks
Expand All @@ -34,12 +34,11 @@

# Learning parameters
checkpoint = None # path to model checkpoint, None if none
batch_size = 16 # batch size
batch_size = 1 # batch size
start_epoch = 0 # start at this epoch
epochs = 50 # number of training epochs
workers = 4 # number of workers for loading data in the DataLoader
epochs = 10_000 # number of training epochs
print_freq = 500 # print training status once every __ batches
lr = 1e-6 # learning rate
lr = 1e-5 # learning rate
beta = 1e-3 # the coefficient to weight the adversarial loss in the perceptual loss


Expand Down Expand Up @@ -95,4 +94,4 @@ def main(architecture_type: str = "resnet"):


if __name__ == "__main__":
main(architecture_type="resnet")
main(architecture_type="gan")
25 changes: 25 additions & 0 deletions trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,28 @@ def compile(self):
else:
raise NotImplementedError("Trainer not defined for this type of architecture")

def save_checkpoint(self, name: str, epoch: int):
    """
    Exports the current model as a TensorFlow SavedModel into the directory "<name>_<epoch>/".

    The export carries a single serving signature that accepts a float32 batch of 3-channel images
    with unconstrained batch size, height and width.

    :param name: model name, used as the prefix of the export directory
    :param epoch: the epoch being saved, used as the suffix of the export directory
    """
    model = self.architecture.model

    # Only the channel count is fixed; batch size and spatial dimensions stay dynamic
    image_spec = tf.TensorSpec(shape=[None, None, None, 3], dtype=tf.float32)

    @tf.function(input_signature=[image_spec])
    def serving_fn(image_batch):
        """
        Forward pass used as the exported serving signature.

        :param image_batch: input image place-holder
        :return: the model's output for the batch
        """
        return model(image_batch)

    export_dir = f"{name}_{epoch}/"
    tf.saved_model.save(model, export_dir=export_dir, signatures=serving_fn)

def train(self, start_epoch: int, epochs: int, batch_size: int, print_freq: int):
"""
Train the given model architecture.
Expand Down Expand Up @@ -87,6 +109,9 @@ def train(self, start_epoch: int, epochs: int, batch_size: int, print_freq: int
print(f'Epoch: [{epoch}][{i}/{epochs}]----'
f'Loss {loss:.4f}')

if (epoch + 1) % 10_000 == 0:
self.save_checkpoint(name=self.architecture.model.__class__.__name__, epoch=epoch)

@staticmethod
def create_dataset(
data_folder: str,
Expand Down

0 comments on commit d16de1a

Please sign in to comment.