Skip to content

Commit

Permalink
switched to global pooling, toggled training mode, modified params
Browse files Browse the repository at this point in the history
  • Loading branch information
AndreiMoraru123 committed Jul 29, 2023
1 parent e21a490 commit 49e96cc
Show file tree
Hide file tree
Showing 5 changed files with 22 additions and 25 deletions.
8 changes: 4 additions & 4 deletions architecture.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,14 +106,14 @@ def __init__(
@tf.function(jit_compile=True)
def train_step(self, low_res_images: Image.Image, high_res_images: Image.Image) -> Tuple[Loss, Loss]:
with tf.GradientTape() as gen_tape:
super_res_images = self.model(low_res_images)
super_res_images = self.model(low_res_images, training=True)
super_res_images = self.transform.convert_image(super_res_images,
source='[-1, 1]',
target='imagenet-norm')
super_res_images_vgg_space = self.vgg(super_res_images)
high_res_images_vgg_space = self.vgg(tf.stop_gradient(high_res_images)) # does not get updated

super_res_discriminated = self.model2(super_res_images)
super_res_discriminated = self.model2(super_res_images, training=True)

content_loss = self.loss_fn(super_res_images_vgg_space, high_res_images_vgg_space)
adversarial_loss = self.loss_fn2(super_res_discriminated, tf.ones_like(super_res_discriminated))
Expand All @@ -123,8 +123,8 @@ def train_step(self, low_res_images: Image.Image, high_res_images: Image.Image)
self.optimizer.apply_gradients(zip(gen_gradients, self.model.trainable_variables))

with tf.GradientTape() as dis_tape:
super_res_discriminated = self.model2(tf.stop_gradient(super_res_images)) # re-evaluate without gradients
high_res_discriminated = self.model2(high_res_images)
super_res_discriminated = self.model2(tf.stop_gradient(super_res_images), training=True)
high_res_discriminated = self.model2(high_res_images, training=True)

adversarial_loss = \
self.loss_fn2(super_res_discriminated, tf.zeros_like(super_res_discriminated)) + \
Expand Down
7 changes: 2 additions & 5 deletions model.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,9 +271,7 @@ def __init__(self, kernel_size: int = 3, n_channels: int = 64, n_blocks: int = 3

self.conv_blocks = tf.keras.Sequential(conv_blocks)

self.average_pool = layers.AveragePooling2D(pool_size=(4, 4))

self.flatten = layers.Flatten()
self.pool = layers.GlobalAveragePooling2D()

self.fc1 = layers.Dense(fc_size)

Expand All @@ -291,8 +289,7 @@ def call(self, images: tf.Tensor, training: bool = False) -> tf.Tensor:
:return: a score (logit) for whether it is a high-resolution image, a Tensor of shape (N)
"""
output = self.conv_blocks(images, training=training)
output = self.average_pool(output)
output = self.flatten(output)
output = self.pool(output)
output = self.fc1(output)
output = self.leaky_relu(output)
logit = self.fc2(output) # (N, 1) as Keras retains the last dimension for convenience
Expand Down
16 changes: 8 additions & 8 deletions resolve.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@
crop_size = 96 # crop size of target HR images
scaling_factor = 4 # the input LR images will be down-sampled from the target HR images by this factor

resnet = tf.saved_model.load("SuperResolutionResNet_9999")
resnet = tf.saved_model.load("SuperResolutionResNet_99999")
resnet_inference = resnet.signatures["serving_default"]

generator = tf.saved_model.load("Generator_9999")
generator = tf.saved_model.load("Generator_99999")
generator_inference = generator.signatures["serving_default"]

# Need this instance for conversions
Expand Down Expand Up @@ -42,13 +42,12 @@ def super_resolve(img: str, halve: bool = False):
hr_img = hr_img.convert('RGB')

if halve:
hr_img = hr_img.resize((int(hr_img.width / 2), int(hr_img.height / 2)),
Image.LANCZOS)
hr_img = hr_img.resize((int(hr_img.width / 2), int(hr_img.height / 2)), Image.LANCZOS)

lr_img = hr_img.resize((int(hr_img.width / 4), int(hr_img.height / 4)),
Image.BICUBIC)
# Create low resolution image at runtime
lr_img = hr_img.resize((int(hr_img.width / 4), int(hr_img.height / 4)), Image.BICUBIC)

# Bicubic Upsampling
# Bicubic Up-sampling
bicubic_img = lr_img.resize((hr_img.width, hr_img.height), Image.BICUBIC)

lr_img = tf.expand_dims(transform.convert_image(lr_img, source='pil', target='imagenet-norm'), axis=0)
Expand All @@ -58,6 +57,7 @@ def super_resolve(img: str, halve: bool = False):
sr_img_srresnet = tf.squeeze(sr_img_srresnet['output_0'])
sr_img_srresnet = transform.convert_image(sr_img_srresnet, source='[-1, 1]', target='pil')

# Super-resolution (SR) with SRGAN
sr_img_srgan = generator_inference(lr_img)
sr_img_srgan = tf.squeeze(sr_img_srgan['output_0'])
sr_img_srgan = transform.convert_image(sr_img_srgan, source='[-1, 1]', target='pil')
Expand Down Expand Up @@ -102,4 +102,4 @@ def super_resolve(img: str, halve: bool = False):


if __name__ == '__main__':
super_resolve("bird.jpeg", halve=False)
super_resolve("bird.jpeg", halve=True)
14 changes: 7 additions & 7 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,13 @@
# Common Model parameters
large_kernel_size = 9 # kernel size of the first and last convolutions which transform the inputs and outputs
small_kernel_size = 3 # kernel size of all convolutions in-between, i.e. those in the residual & subpixel conv blocks
n_channels = 64 # number of channels in-between, input and output channels for residual & subpixel conv blocks
n_blocks = 32 # number of residual blocks
srresnet_checkpoint = "SuperResolutionResNet_9999" # trained SRResNet checkpoint used for initialization
n_channels = 128 # number of channels in-between, input and output channels for residual & subpixel conv blocks
n_blocks = 64 # number of residual blocks
srresnet_checkpoint = "SuperResolutionResNet_99999" # trained SRResNet checkpoint used for generator initialization

# Discriminator parameters
kernel_size_d = 3 # kernel size in all convolutional blocks
n_channels_d = 64 # number of channels for the discriminator's convolutional blocks (see the discriminator in model.py for how it is applied)
n_channels_d = 128 # number of channels for the discriminator's convolutional blocks (see the discriminator in model.py for how it is applied)
n_blocks_d = 8 # number of convolutional blocks
fc_size_d = 1024 # size of the first fully connected layer

Expand All @@ -36,9 +36,9 @@
checkpoint = None # path to model checkpoint, None if none
batch_size = 1 # batch size
start_epoch = 0 # start at this epoch
epochs = 10_000 # number of training epochs
epochs = 100_000 # number of training epochs
print_freq = 500 # print training status once every __ batches
lr = 1e-5 # learning rate
lr = 1e-6 # learning rate
beta = 1e-3 # the coefficient to weight the adversarial loss in the perceptual loss


Expand Down Expand Up @@ -94,4 +94,4 @@ def main(architecture_type: str = "resnet"):


if __name__ == "__main__":
main(architecture_type="resnet")
main(architecture_type="gan")
2 changes: 1 addition & 1 deletion trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def train(self, start_epoch: int, epochs: int, batch_size: int, print_freq: int
print(f'Epoch: [{epoch}][{i}/{epochs}]----'
f'Loss {loss:.4f}')

if (epoch + 1) % 10_000 == 0:
if (epoch + 1) % 100_000 == 0:
self.save_checkpoint(name=self.architecture.model.__class__.__name__, epoch=epoch)

@staticmethod
Expand Down

0 comments on commit 49e96cc

Please sign in to comment.