unified training pipeline refactoring
AndreiMoraru123 committed Jul 25, 2023
1 parent 3370ac2 commit 66ea1db
Showing 9 changed files with 387 additions and 304 deletions.
64 changes: 0 additions & 64 deletions dataset.py

This file was deleted.
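With this refactoring, dataset creation moves from the deleted module onto the Trainer itself: the updated test below calls trainer.create_dataset(**config) where it previously called the standalone create_dataset. A minimal sketch of the relocated method's likely signature, inferred only from the config keys the test passes (the body and any defaults are assumptions, not the actual implementation):

    class Trainer:
        def create_dataset(self, data_folder, split, crop_size, scaling_factor,
                           low_res_img_type, high_res_img_type, test_data_name=""):
            """Build a tf.data.Dataset from the JSON image lists (sketch only)."""
            ...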

4 changes: 3 additions & 1 deletion requirements.txt
@@ -3,4 +3,6 @@ pytest~=7.3.1
 numpy~=1.24.3
 pillow~=9.4.0
 einops~=0.6.1
-python-dotenv~=1.0.0
+python-dotenv~=1.0.0
+tensorboard~=2.12.3
+colorama~=0.4.4
24 changes: 18 additions & 6 deletions test/test_dataset.py
@@ -1,14 +1,14 @@
 # standard imports
 import json
-from unittest.mock import patch, MagicMock
+from unittest.mock import patch, MagicMock, PropertyMock
 
 # third party imports
 import pytest  # type: ignore
 import tensorflow as tf  # type: ignore
 from PIL import Image  # type: ignore
 
 # module imports
-from dataset import create_dataset
+from trainer import Trainer
 
 
 @pytest.fixture(name="json_path")
@@ -37,15 +37,27 @@ def dataset_config(json_path):
         "split": "test",
         "crop_size": 96,
         "scaling_factor": 4,
-        "lr_img_type": "[0, 255]",
-        "hr_img_type": "[0, 255]",
+        "low_res_img_type": "[0, 255]",
+        "high_res_img_type": "[0, 255]",
         "test_data_name": "dummy",
     }
     return config
 
 
+@pytest.fixture(name="trainer")
+def mock_trainer(config):
+    """Mock Trainer with no compilation step and mock architecture."""
+    mock_compile = patch.object(Trainer, "compile", return_value=None)
+    mock_compile.start()
+
+    instance = Trainer(data_folder=config['data_folder'], architecture=MagicMock())
+    yield instance
+    # Optional here, since I don't care if it persists
+    mock_compile.stop()
+
+
 @patch("PIL.Image.open")
-def test_dataset_creation(mock_img_open, config):
+def test_dataset_creation(mock_img_open, trainer, config):
     """Test Dataset creation with mocked paths and image."""
 
     # Mock image object
@@ -58,7 +70,7 @@ def test_dataset_creation(mock_img_open, config):
     # Return the mock image on PIL Image open
     mock_img_open.return_value = mock_img
     # Create Dataset from config
-    dataset = create_dataset(**config)
+    dataset = trainer.create_dataset(**config)
     # "image1.jpg", "image2.jpg", "image3.jpg"
     assert len(list(dataset.as_numpy_iterator())) == 3
     # assert data content
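A note on the mock_trainer fixture above: patch.object also works as a context manager, which undoes the patch even if the test fails before teardown runs. An equivalent, cleanup-safe variant (a sketch, not part of this commit):

    @pytest.fixture(name="trainer")
    def mock_trainer(config):
        with patch.object(Trainer, "compile", return_value=None):
            yield Trainer(data_folder=config["data_folder"], architecture=MagicMock())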
98 changes: 98 additions & 0 deletions train.py
@@ -0,0 +1,98 @@
# third-party imports
import tensorflow as tf
from dotenv import load_dotenv

# module imports
from trainer import Trainer
from transforms import ImageTransform
from utils import ResNetArchitecture, GANArchitecture
from model import SuperResolutionResNet, Generator, Discriminator, TruncatedVGG19

load_dotenv()

# Data parameters
data_folder = './' # folder with JSON data files
crop_size = 96 # crop size of target HR images
scaling_factor = 4 # the input LR images will be down-sampled from the target HR images by this factor

# Common Model parameters
large_kernel_size = 9 # kernel size of the first and last convolutions which transform the inputs and outputs
small_kernel_size = 3 # kernel size of all convolutions in-between, i.e. those in the residual & subpixel conv blocks
n_channels = 64 # number of channels in-between, i.e. the input and output channels for the residual & subpixel conv blocks
n_blocks = 16 # number of residual blocks
srresnet_checkpoint = "srresnet" # filepath of the trained SRResNet checkpoint used for initialization

# Discriminator parameters
kernel_size_d = 3 # kernel size in all convolutional blocks
n_channels_d = 64 # number of output channels in the first convolutional block, doubled in every second block thereafter
n_blocks_d = 8 # number of convolutional blocks
fc_size_d = 1024 # size of the first fully connected layer

# VGG parameters
vgg19_i = 5 # the index i in the definition for VGG loss; see paper or models.py
vgg19_j = 4 # the index j in the definition for VGG loss; see paper or models.py
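# With i=5, j=4 this selects the feature map produced by the 4th convolution (after
# activation) before the 5th max-pooling layer of VGG19, as defined in the SRGAN paper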

# Learning parameters
checkpoint = None # path to a model checkpoint to resume from, or None to start from scratch
batch_size = 16 # batch size
start_epoch = 0 # start at this epoch
epochs = 50 # number of training epochs
workers = 4 # number of parallel workers for loading the data
print_freq = 500 # print training status once every __ batches
lr = 1e-6 # learning rate
beta = 1e-3 # the coefficient to weight the adversarial loss in the perceptual loss
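# Assumption, since GANArchitecture is not shown in this commit: following the SRGAN
# paper, beta presumably enters the total loss as
#   perceptual_loss = content_loss + beta * adversarial_loss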


def main(architecture_type: str = "resnet"):
    """
    Manages the whole training pipeline given a model architecture.
    :param architecture_type: resnet or gan
    """
    optimizer = tf.keras.optimizers.Adam(learning_rate=lr)
    loss_fn = tf.keras.losses.MeanSquaredError()

    if architecture_type == "resnet":
        model = SuperResolutionResNet(large_kernel_size=large_kernel_size, small_kernel_size=small_kernel_size,
                                      n_channels=n_channels, n_blocks=n_blocks, scaling_factor=scaling_factor)
        architecture = ResNetArchitecture(model=model, optimizer=optimizer, loss_fn=loss_fn)

    elif architecture_type == "gan":
        generator = Generator(large_kernel_size=large_kernel_size,
                              small_kernel_size=small_kernel_size,
                              n_channels=n_channels,
                              n_blocks=n_blocks,
                              scaling_factor=scaling_factor)

        generator.initialize_with_srresnet(srresnet_checkpoint=srresnet_checkpoint)

        discriminator = Discriminator(kernel_size=kernel_size_d,
                                      n_channels=n_channels_d,
                                      n_blocks=n_blocks_d,
                                      fc_size=fc_size_d)

        adversarial_loss = tf.keras.losses.BinaryCrossentropy()

        optimizer_d = tf.keras.optimizers.Adam(learning_rate=lr)

        transform = ImageTransform(split="train",
                                   crop_size=crop_size,
                                   lr_img_type='imagenet-norm',
                                   hr_img_type='[-1, 1]',
                                   scaling_factor=scaling_factor)

        truncated_vgg19 = TruncatedVGG19(i=vgg19_i, j=vgg19_j)

        architecture = GANArchitecture(gen_model=generator, dis_model=discriminator,
                                       gen_optimizer=optimizer, dis_optimizer=optimizer_d,
                                       content_loss=loss_fn, adversarial_loss=adversarial_loss,
                                       transform=transform, vgg=truncated_vgg19)
    else:
        raise NotImplementedError("Model architecture not implemented")

    trainer = Trainer(architecture=architecture, data_folder=data_folder)
    trainer.train(start_epoch=start_epoch, epochs=epochs, batch_size=batch_size, print_freq=print_freq)


if __name__ == "__main__":
    main(architecture_type="gan")
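Because the gan branch initializes its generator from the srresnet checkpoint, the resnet pass has to run first. A plausible two-stage invocation, assuming checkpoints are saved under the names configured above:

    main(architecture_type="resnet")  # pretrain SRResNet, producing the "srresnet" checkpoint
    main(architecture_type="gan")     # then train SRGAN, initializing its generator from it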
143 changes: 0 additions & 143 deletions train_srgan.py

This file was deleted.
