From c4b7e5ab55bbf51febfb684d372f8e2f8c08761b Mon Sep 17 00:00:00 2001 From: AndreiMoraru123 <81184255+AndreiMoraru123@users.noreply.github.com> Date: Thu, 27 Jul 2023 22:24:27 +0300 Subject: [PATCH] lint --- utils.py => architecture.py | 7 ++++++- model.py | 6 ++++-- test/test_dataset.py | 2 +- test/test_model.py | 18 ++++++------------ train.py | 4 ++-- trainer.py | 20 ++++++++++++++------ 6 files changed, 33 insertions(+), 24 deletions(-) rename utils.py => architecture.py (97%) diff --git a/utils.py b/architecture.py similarity index 97% rename from utils.py rename to architecture.py index 646787b..4db18e7 100644 --- a/utils.py +++ b/architecture.py @@ -55,6 +55,8 @@ def train_step(self, low_res_images: Image.Image, high_res_images: Image.Image): class ResNetArchitecture(Architecture): + """Super Resolution ResNet.""" + @tf.function(jit_compile=True) def train_step(self, low_res_images: Image.Image, high_res_images: Image.Image) -> losses.Loss: with tf.GradientTape() as tape: @@ -68,6 +70,8 @@ def train_step(self, low_res_images: Image.Image, high_res_images: Image.Image) class GANArchitecture(Architecture): + """Super Resolution GAN.""" + def __init__( self, gen_model: Model, @@ -104,13 +108,14 @@ def train_step( low_res_images: Image.Image, high_res_images: Image.Image, ) -> Tuple[losses.Loss, losses.Loss]: + with tf.GradientTape() as gen_tape: super_res_images = self.model(low_res_images) super_res_images = self.transform.convert_image(super_res_images, source='[-1, 1]', target='imagenet-norm') super_res_images_vgg_space = self.vgg(super_res_images) - high_res_images_vgg_space = self.vgg(tf.stop_gradient(high_res_images)) # do not get updated + high_res_images_vgg_space = self.vgg(tf.stop_gradient(high_res_images)) # does not get updated super_res_discriminated = self.model2(super_res_images) diff --git a/model.py b/model.py index ed7233e..fdee698 100644 --- a/model.py +++ b/model.py @@ -5,6 +5,7 @@ # third-party imports import tensorflow as tf # type: ignore from tensorflow.keras import layers, Model # type: ignore +from tensorflow.keras.applications import VGG19 # type: ignore class ConvolutionalBlock(layers.Layer): @@ -236,6 +237,7 @@ def call(self, low_res_images: tf.Tensor, training: bool = False) -> tf.Tensor: """ Forward pass of the Generator + :param training: whether the layer is in training mode or not :param low_res_images: low-resolution input images, a tensor of size (N, w, h, 3) :return: super-resolution output images, a tensor of size (N, w * scaling factor, h * scaling factor, 3) """ @@ -315,7 +317,7 @@ def __init__(self, i: int, j: int, **kwargs): """ super().__init__(**kwargs) - vgg19 = tf.keras.applications.VGG19(include_top=False) + vgg19 = VGG19(include_top=False) maxpool_counter = 0 conv_counter = 0 truncate_at = None @@ -324,7 +326,7 @@ def __init__(self, i: int, j: int, **kwargs): if isinstance(layer, layers.Conv2D): conv_counter += 1 if isinstance(layer, layers.MaxPooling2D): - maxpool_counter +=1 + maxpool_counter += 1 conv_counter = 0 # Break if we reach the jth convolution after the (i-1)th max-pool diff --git a/test/test_dataset.py b/test/test_dataset.py index 41f0197..e179576 100644 --- a/test/test_dataset.py +++ b/test/test_dataset.py @@ -1,6 +1,6 @@ # standard imports import json -from unittest.mock import patch, MagicMock, PropertyMock +from unittest.mock import patch, MagicMock # third party imports import pytest # type: ignore diff --git a/test/test_model.py b/test/test_model.py index 0c36d6a..7903822 100644 --- a/test/test_model.py +++ b/test/test_model.py @@ -19,8 +19,7 @@ def __init__(self, shape): @pytest.fixture( - params= - [ + params=[ (32, 64, 3, 1, False, None), (32, 64, 3, 1, True, 'prelu'), (32, 64, 5, 1, False, 'leakyrelu'), @@ -33,8 +32,7 @@ def conv_block_params(request): @pytest.fixture( - params= - [ + params=[ (3, 64, 2), (5, 32, 3), (3, 128, 4), @@ -47,8 +45,7 @@ def subpixel_conv_block_params(request): @pytest.fixture( - params= - [ + params=[ (3, 64), (5, 32), (3, 128), @@ -61,8 +58,7 @@ def residual_block_params(request): @pytest.fixture( - params= - [ + params=[ (9, 3, 64, 16, 2), (9, 3, 64, 16, 4), (9, 3, 64, 16, 8), @@ -75,8 +71,7 @@ def sr_resnet_params(request): @pytest.fixture( - params= - [ + params=[ (3, 64, 8, 1024), (3, 32, 6, 512), (3, 16, 4, 2048), @@ -89,8 +84,7 @@ def discriminator_params(request): @pytest.fixture( - params= - [ + params=[ (2, 1), (3, 2), (4, 3), diff --git a/train.py b/train.py index 63342eb..2aad75d 100644 --- a/train.py +++ b/train.py @@ -5,7 +5,7 @@ # module imports from trainer import Trainer from transforms import ImageTransform -from utils import ResNetArchitecture, GANArchitecture +from architecture import ResNetArchitecture, GANArchitecture from model import SuperResolutionResNet, Generator, Discriminator, TruncatedVGG19 load_dotenv() @@ -95,4 +95,4 @@ def main(architecture_type: str = "resnet"): if __name__ == "__main__": - main(architecture_type="gan") + main(architecture_type="resnet") diff --git a/trainer.py b/trainer.py index 6eb1fba..d7620f8 100644 --- a/trainer.py +++ b/trainer.py @@ -1,9 +1,6 @@ # standard imports import os import json -import time -import shutil -import logging # third-party imports import tensorflow as tf # type: ignore @@ -14,11 +11,12 @@ # module imports from transforms import ImageTransform -from utils import Architecture, ResNetArchitecture, GANArchitecture +from architecture import Architecture, ResNetArchitecture, GANArchitecture class Trainer: - """Utility class to train super resolution models""" + """Utility class to train super resolution models.""" + def __init__( self, architecture: Architecture, @@ -48,6 +46,7 @@ def __init__( def compile(self): """Compiles the model with the optimizer and loss criterion.""" + if isinstance(self.architecture, GANArchitecture): self.architecture.model.compile(optimizer=self.architecture.optimizer, loss=self.architecture.loss_fn) self.architecture.model2.compile(optimizer=self.architecture.optimizer2, loss=self.architecture.loss_fn2) @@ -57,7 +56,15 @@ def compile(self): raise NotImplementedError("Trainer not defined for this type of architecture") def train(self, start_epoch: int, epochs: int, batch_size: int, print_freq: int): - """Trains the model for the given number of epochs.""" + """ + Train the given model architecture. + + :param start_epoch: starting epoch + :param epochs: total number of epochs + :param batch_size: how many images the model sees at once + :param print_freq: log stats with this frequency + """ + self.dataset = self.dataset.batch(batch_size=batch_size) self.dataset = self.dataset.prefetch(tf.data.AUTOTUNE) @@ -122,6 +129,7 @@ def create_dataset( def generator(): """Data generator for the TensorFlow Dataset.""" + for image_path in images: img = Image.open(image_path, mode='r') img = img.convert('RGB')