Skip to content

Commit

Permalink
lint
Browse files Browse the repository at this point in the history
  • Loading branch information
AndreiMoraru123 committed Jul 27, 2023
1 parent 8a09410 commit c4b7e5a
Show file tree
Hide file tree
Showing 6 changed files with 33 additions and 24 deletions.
7 changes: 6 additions & 1 deletion utils.py → architecture.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ def train_step(self, low_res_images: Image.Image, high_res_images: Image.Image):


class ResNetArchitecture(Architecture):
"""Super Resolution ResNet."""

@tf.function(jit_compile=True)
def train_step(self, low_res_images: Image.Image, high_res_images: Image.Image) -> losses.Loss:
with tf.GradientTape() as tape:
Expand All @@ -68,6 +70,8 @@ def train_step(self, low_res_images: Image.Image, high_res_images: Image.Image)


class GANArchitecture(Architecture):
"""Super Resolution GAN."""

def __init__(
self,
gen_model: Model,
Expand Down Expand Up @@ -104,13 +108,14 @@ def train_step(
low_res_images: Image.Image,
high_res_images: Image.Image,
) -> Tuple[losses.Loss, losses.Loss]:

with tf.GradientTape() as gen_tape:
super_res_images = self.model(low_res_images)
super_res_images = self.transform.convert_image(super_res_images,
source='[-1, 1]',
target='imagenet-norm')
super_res_images_vgg_space = self.vgg(super_res_images)
high_res_images_vgg_space = self.vgg(tf.stop_gradient(high_res_images)) # do not get updated
high_res_images_vgg_space = self.vgg(tf.stop_gradient(high_res_images)) # does not get updated

super_res_discriminated = self.model2(super_res_images)

Expand Down
6 changes: 4 additions & 2 deletions model.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
# third-party imports
import tensorflow as tf # type: ignore
from tensorflow.keras import layers, Model # type: ignore
from tensorflow.keras.applications import VGG19 # type: ignore


class ConvolutionalBlock(layers.Layer):
Expand Down Expand Up @@ -236,6 +237,7 @@ def call(self, low_res_images: tf.Tensor, training: bool = False) -> tf.Tensor:
"""
Forward pass of the Generator
:param training: whether the layer is in training mode or not
:param low_res_images: low-resolution input images, a tensor of size (N, w, h, 3)
:return: super-resolution output images, a tensor of size (N, w * scaling factor, h * scaling factor, 3)
"""
Expand Down Expand Up @@ -315,7 +317,7 @@ def __init__(self, i: int, j: int, **kwargs):
"""
super().__init__(**kwargs)

vgg19 = tf.keras.applications.VGG19(include_top=False)
vgg19 = VGG19(include_top=False)
maxpool_counter = 0
conv_counter = 0
truncate_at = None
Expand All @@ -324,7 +326,7 @@ def __init__(self, i: int, j: int, **kwargs):
if isinstance(layer, layers.Conv2D):
conv_counter += 1
if isinstance(layer, layers.MaxPooling2D):
maxpool_counter +=1
maxpool_counter += 1
conv_counter = 0

# Break if we reach the jth convolution after the (i-1)th max-pool
Expand Down
2 changes: 1 addition & 1 deletion test/test_dataset.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# standard imports
import json
from unittest.mock import patch, MagicMock, PropertyMock
from unittest.mock import patch, MagicMock

# third party imports
import pytest # type: ignore
Expand Down
18 changes: 6 additions & 12 deletions test/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,7 @@ def __init__(self, shape):


@pytest.fixture(
params=
[
params=[
(32, 64, 3, 1, False, None),
(32, 64, 3, 1, True, 'prelu'),
(32, 64, 5, 1, False, 'leakyrelu'),
Expand All @@ -33,8 +32,7 @@ def conv_block_params(request):


@pytest.fixture(
params=
[
params=[
(3, 64, 2),
(5, 32, 3),
(3, 128, 4),
Expand All @@ -47,8 +45,7 @@ def subpixel_conv_block_params(request):


@pytest.fixture(
params=
[
params=[
(3, 64),
(5, 32),
(3, 128),
Expand All @@ -61,8 +58,7 @@ def residual_block_params(request):


@pytest.fixture(
params=
[
params=[
(9, 3, 64, 16, 2),
(9, 3, 64, 16, 4),
(9, 3, 64, 16, 8),
Expand All @@ -75,8 +71,7 @@ def sr_resnet_params(request):


@pytest.fixture(
params=
[
params=[
(3, 64, 8, 1024),
(3, 32, 6, 512),
(3, 16, 4, 2048),
Expand All @@ -89,8 +84,7 @@ def discriminator_params(request):


@pytest.fixture(
params=
[
params=[
(2, 1),
(3, 2),
(4, 3),
Expand Down
4 changes: 2 additions & 2 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
# module imports
from trainer import Trainer
from transforms import ImageTransform
from utils import ResNetArchitecture, GANArchitecture
from architecture import ResNetArchitecture, GANArchitecture
from model import SuperResolutionResNet, Generator, Discriminator, TruncatedVGG19

load_dotenv()
Expand Down Expand Up @@ -95,4 +95,4 @@ def main(architecture_type: str = "resnet"):


if __name__ == "__main__":
main(architecture_type="gan")
main(architecture_type="resnet")
20 changes: 14 additions & 6 deletions trainer.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
# standard imports
import os
import json
import time
import shutil
import logging

# third-party imports
import tensorflow as tf # type: ignore
Expand All @@ -14,11 +11,12 @@

# module imports
from transforms import ImageTransform
from utils import Architecture, ResNetArchitecture, GANArchitecture
from architecture import Architecture, ResNetArchitecture, GANArchitecture


class Trainer:
"""Utility class to train super resolution models"""
"""Utility class to train super resolution models."""

def __init__(
self,
architecture: Architecture,
Expand Down Expand Up @@ -48,6 +46,7 @@ def __init__(

def compile(self):
"""Compiles the model with the optimizer and loss criterion."""

if isinstance(self.architecture, GANArchitecture):
self.architecture.model.compile(optimizer=self.architecture.optimizer, loss=self.architecture.loss_fn)
self.architecture.model2.compile(optimizer=self.architecture.optimizer2, loss=self.architecture.loss_fn2)
Expand All @@ -57,7 +56,15 @@ def compile(self):
raise NotImplementedError("Trainer not defined for this type of architecture")

def train(self, start_epoch: int, epochs: int, batch_size: int, print_freq: int):
"""Trains the model for the given number of epochs."""
"""
Train the given model architecture.
:param start_epoch: starting epoch
:param epochs: total number of epochs
:param batch_size: how many images the model sees at once
:param print_freq: log stats with this frequency
"""

self.dataset = self.dataset.batch(batch_size=batch_size)
self.dataset = self.dataset.prefetch(tf.data.AUTOTUNE)

Expand Down Expand Up @@ -122,6 +129,7 @@ def create_dataset(

def generator():
"""Data generator for the TensorFlow Dataset."""

for image_path in images:
img = Image.open(image_path, mode='r')
img = img.convert('RGB')
Expand Down

0 comments on commit c4b7e5a

Please sign in to comment.