logging training
AndreiMoraru123 committed Jul 30, 2023
1 parent 49e96cc commit 0805469
Showing 2 changed files with 156 additions and 72 deletions.
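The core of the change is TensorFlow 2's summary API: create a file writer once, then emit scalar summaries inside the training loop. A minimal standalone sketch of that pattern (the writer path and tag name here are illustrative, not taken from this commit):

import tensorflow as tf

# create a writer that serializes event files for TensorBoard
writer = tf.summary.create_file_writer("logs")
with writer.as_default():
    for step in range(100):
        fake_loss = 1.0 / (step + 1)  # stand-in value for demonstration
        tf.summary.scalar("Loss", fake_loss, step=step)

The resulting event files can then be inspected with tensorboard --logdir logs.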
train.py (72 changes: 48 additions & 24 deletions)
@@ -11,7 +11,7 @@
 load_dotenv()

 # Data parameters
-data_folder = './' # folder with JSON data files
+data_folder = "./"  # folder with JSON data files
 crop_size = 96  # crop size of target HR images
 scaling_factor = 4  # the input LR images will be down-sampled from the target HR images by this factor
@@ -52,46 +52,70 @@ def main(architecture_type: str = "resnet"):
     loss_fn = tf.keras.losses.MeanSquaredError()

     if architecture_type == "resnet":
-        model = SuperResolutionResNet(large_kernel_size=large_kernel_size, small_kernel_size=small_kernel_size,
-                                      n_channels=n_channels, n_blocks=n_blocks, scaling_factor=scaling_factor)
-        architecture = ResNetArchitecture(model=model, optimizer=optimizer, loss_fn=loss_fn)
+        model = SuperResolutionResNet(
+            large_kernel_size=large_kernel_size,
+            small_kernel_size=small_kernel_size,
+            n_channels=n_channels,
+            n_blocks=n_blocks,
+            scaling_factor=scaling_factor,
+        )
+        architecture = ResNetArchitecture(
+            model=model, optimizer=optimizer, loss_fn=loss_fn
+        )

     elif architecture_type == "gan":
-        generator = Generator(large_kernel_size=large_kernel_size,
-                              small_kernel_size=small_kernel_size,
-                              n_channels=n_channels,
-                              n_blocks=n_blocks,
-                              scaling_factor=scaling_factor)
+        generator = Generator(
+            large_kernel_size=large_kernel_size,
+            small_kernel_size=small_kernel_size,
+            n_channels=n_channels,
+            n_blocks=n_blocks,
+            scaling_factor=scaling_factor,
+        )

         generator.initialize_with_srresnet(srresnet_checkpoint=srresnet_checkpoint)

-        discriminator = Discriminator(kernel_size=kernel_size_d,
-                                      n_channels=n_channels_d,
-                                      n_blocks=n_blocks_d,
-                                      fc_size=fc_size_d)
+        discriminator = Discriminator(
+            kernel_size=kernel_size_d,
+            n_channels=n_channels_d,
+            n_blocks=n_blocks_d,
+            fc_size=fc_size_d,
+        )

         adversarial_loss = tf.keras.losses.BinaryCrossentropy()

         optimizer_d = tf.keras.optimizers.Adam(learning_rate=lr)

-        transform = ImageTransform(split="train",
-                                   crop_size=crop_size,
-                                   lr_img_type='imagenet-norm',
-                                   hr_img_type='[-1, 1]',
-                                   scaling_factor=scaling_factor)
+        transform = ImageTransform(
+            split="train",
+            crop_size=crop_size,
+            lr_img_type="imagenet-norm",
+            hr_img_type="[-1, 1]",
+            scaling_factor=scaling_factor,
+        )

         truncated_vgg19 = TruncatedVGG19(i=vgg19_i, j=vgg19_j)

-        architecture = GANArchitecture(gen_model=generator, dis_model=discriminator,
-                                       gen_optimizer=optimizer, dis_optimizer=optimizer_d,
-                                       content_loss=loss_fn, adversarial_loss=adversarial_loss,
-                                       transform=transform, vgg=truncated_vgg19)
+        architecture = GANArchitecture(
+            gen_model=generator,
+            dis_model=discriminator,
+            gen_optimizer=optimizer,
+            dis_optimizer=optimizer_d,
+            content_loss=loss_fn,
+            adversarial_loss=adversarial_loss,
+            transform=transform,
+            vgg=truncated_vgg19,
+        )
     else:
         raise NotImplementedError("Model architecture not implemented")

     trainer = Trainer(architecture=architecture, data_folder=data_folder)
-    trainer.train(start_epoch=start_epoch, epochs=epochs, batch_size=batch_size, print_freq=print_freq)
+    trainer.train(
+        start_epoch=start_epoch,
+        epochs=epochs,
+        batch_size=batch_size,
+        print_freq=print_freq,
+    )


 if __name__ == "__main__":
-    main(architecture_type="gan")
+    main(architecture_type="resnet")
trainer.py (156 changes: 108 additions & 48 deletions)
@@ -1,6 +1,8 @@
 # standard imports
 import os
 import json
+import shutil
+import logging

 # third-party imports
 import tensorflow as tf  # type: ignore
@@ -13,6 +15,10 @@
 from transforms import ImageTransform
 from architecture import Architecture, ResNetArchitecture, GANArchitecture

+# logging house-keeping
+init(autoreset=True)
+logging.basicConfig(level=logging.INFO)


 class Trainer:
     """Utility class to train super resolution models."""
@@ -23,37 +29,66 @@ def __init__(
         data_folder: str,
         crop_size: int = 96,
         scaling_factor: int = 4,
-        low_res_image_type: str = 'imagenet-norm',
-        high_res_image_type: str = '[-1, 1]'
+        low_res_image_type: str = "imagenet-norm",
+        high_res_image_type: str = "[-1, 1]",
+        log_dir: str = "logs",
     ):
"""
Initializes the trainer with the given architecture.
:param architecture: Architecture (model + optimizer + loss)
:param data_folder: folder in which the data is stor
:param crop_size: cropping size for transforms during training
:param scaling_factor: up-scaling factor for higher resolution
:param low_res_image_type: low resolution image type for transform
:param high_res_image_type: high resolution image type for transform
:param log_dir: directory location for TensorBoard logging
"""
         self.architecture = architecture
         self.data_folder = data_folder
         self.crop_size = crop_size
         self.scaling_factor = scaling_factor
         self.low_res_image_type = low_res_image_type
         self.high_res_image_type = high_res_image_type
-        self.dataset = self.create_dataset(data_folder=data_folder, crop_size=crop_size,
-                                           high_res_img_type=high_res_image_type,
-                                           low_res_img_type=low_res_image_type,
-                                           scaling_factor=scaling_factor,
-                                           split="train")
+        self.log_dir = log_dir
+
+        if os.path.exists(log_dir):
+            logging.info(f"{Fore.YELLOW}Flushing Logs")
+            shutil.rmtree(log_dir)
+
+        logging.info(f"{Fore.CYAN}Creating Summary Writer")
+        self.summary_writer = tf.summary.create_file_writer(log_dir)
+
+        logging.info(f"{Fore.MAGENTA}Creating Dataset")
+        self.dataset = self.create_dataset(
+            data_folder=data_folder,
+            crop_size=crop_size,
+            high_res_img_type=high_res_image_type,
+            low_res_img_type=low_res_image_type,
+            scaling_factor=scaling_factor,
+            split="train",
+        )
+        logging.info(f"{Fore.GREEN}Compiling Model")
         self.compile()

     def compile(self):
         """Compiles the model with the optimizer and loss criterion."""

         if isinstance(self.architecture, GANArchitecture):
-            self.architecture.model.compile(optimizer=self.architecture.optimizer, loss=self.architecture.loss_fn)
-            self.architecture.model2.compile(optimizer=self.architecture.optimizer2, loss=self.architecture.loss_fn2)
+            self.architecture.model.compile(
+                optimizer=self.architecture.optimizer, loss=self.architecture.loss_fn
+            )
+            self.architecture.model2.compile(
+                optimizer=self.architecture.optimizer2, loss=self.architecture.loss_fn2
+            )
         elif isinstance(self.architecture, ResNetArchitecture):
-            self.architecture.model.compile(optimizer=self.architecture.optimizer, loss=self.architecture.loss_fn)
+            self.architecture.model.compile(
+                optimizer=self.architecture.optimizer, loss=self.architecture.loss_fn
+            )
         else:
-            raise NotImplementedError("Trainer not defined for this type of architecture")
+            raise NotImplementedError(
+                "Trainer not defined for this type of architecture"
+            )

     def save_checkpoint(self, name: str, epoch: int):
         """
@@ -63,7 +98,11 @@ def save_checkpoint(self, name: str, epoch: int):
         :param epoch: the given epoch for which to save the model.
         """

-        @tf.function(input_signature=[tf.TensorSpec(shape=[None, None, None, 3], dtype=tf.float32)])
+        @tf.function(
+            input_signature=[
+                tf.TensorSpec(shape=[None, None, None, 3], dtype=tf.float32)
+            ]
+        )
         def serving_fn(image_batch):
             """
             Serving function for saving the model.
@@ -73,11 +112,13 @@ def serving_fn(image_batch):
             """
             return self.architecture.model(image_batch)

-        tf.saved_model.save(self.architecture.model,
-                            export_dir=f"{name}_{epoch}/",
-                            signatures=serving_fn)
+        tf.saved_model.save(
+            self.architecture.model,
+            export_dir=f"{name}_{epoch}/",
+            signatures=serving_fn,
+        )

     def train(self, start_epoch: int, epochs: int, batch_size: int, print_freq: int):
"""
Train the given model architecture.
Expand All @@ -95,22 +136,33 @@ def train(self, start_epoch: int, epochs: int, batch_size: int, print_freq: int
low_res_images = tf.dtypes.cast(low_res_images, tf.float32)
high_res_imgs = tf.dtypes.cast(high_res_imgs, tf.float32)

-                loss = self.architecture.train_step(low_res_images=low_res_images,
-                                                    high_res_images=high_res_imgs)
-
-                if isinstance(loss, tuple):
-                    gen_loss, dis_loss = loss
-                    if i % print_freq == 0:
-                        print(f'Epoch: [{epoch}][{i}/{epochs}]----'
-                              f'Generator Loss {gen_loss:.4f}----'
-                              f'Discriminator Loss {dis_loss:.4f}')
-                else:
-                    if i % print_freq == 0:
-                        print(f'Epoch: [{epoch}][{i}/{epochs}]----'
-                              f'Loss {loss:.4f}')
+                loss = self.architecture.train_step(
+                    low_res_images=low_res_images, high_res_images=high_res_imgs
+                )
+
+                with self.summary_writer.as_default():
+                    if isinstance(loss, tuple):
+                        gen_loss, dis_loss = loss
+                        if i % print_freq == 0:
+                            logging.info(
+                                f"{Fore.GREEN}Epoch: [{epoch}][{i}/{epochs}]----"
+                                f"{Fore.YELLOW}Generator Loss {gen_loss:.4f}----"
+                                f"{Fore.CYAN}Discriminator Loss {dis_loss:.4f}"
+                            )
+                        tf.summary.scalar("Generator Loss", gen_loss, step=i)
+                        tf.summary.scalar("Discriminator Loss", dis_loss, step=i)
+                    else:
+                        if i % print_freq == 0:
+                            logging.info(
+                                f"{Fore.GREEN}Epoch: [{epoch}][{i}/{epochs}]----"
+                                f"{Fore.BLUE}Loss {loss:.4f}"
+                            )
+                        tf.summary.scalar("Loss", loss, step=i)

             if (epoch + 1) % 100_000 == 0:
-                self.save_checkpoint(name=self.architecture.model.__class__.__name__, epoch=epoch)
+                self.save_checkpoint(
+                    name=self.architecture.model.__class__.__name__, epoch=epoch
+                )

     @staticmethod
     def create_dataset(
@@ -120,7 +172,7 @@ def create_dataset(
         scaling_factor: int,
         low_res_img_type: str,
         high_res_img_type: str,
-        test_data_name: str = '',
+        test_data_name: str = "",
     ) -> tf.data.Dataset:
         """
         Create a Super Resolution (SR) dataset using TensorFlow's data API.
@@ -133,36 +185,44 @@
         :param high_res_img_type: the format for the HR image supplied to the model
         :param test_data_name: if this is the 'test' split, which test dataset? (for example, "Set14")
         """
-        assert split in {'train', 'test'}
-        if split == 'test' and not test_data_name:
+        assert split in {"train", "test"}
+        if split == "test" and not test_data_name:
             raise ValueError("Please provide the name of the test dataset!")
-        assert low_res_img_type in {'[0, 255]', '[0, 1]', '[-1, 1]', 'imagenet-norm'}
-        assert high_res_img_type in {'[0, 255]', '[0, 1]', '[-1, 1]', 'imagenet-norm'}
+        assert low_res_img_type in {"[0, 255]", "[0, 1]", "[-1, 1]", "imagenet-norm"}
+        assert high_res_img_type in {"[0, 255]", "[0, 1]", "[-1, 1]", "imagenet-norm"}

-        if split == 'train':
-            with open(os.path.join(data_folder, 'train_images.json'), 'r') as f:
+        if split == "train":
+            with open(os.path.join(data_folder, "train_images.json"), "r") as f:
                 images = json.load(f)
         else:
-            with open(os.path.join(data_folder, test_data_name + '_test_images.json'), 'r') as f:
+            with open(
+                os.path.join(data_folder, test_data_name + "_test_images.json"), "r"
+            ) as f:
                 images = json.load(f)

-        transform = ImageTransform(split=split,
-                                   crop_size=crop_size,
-                                   lr_img_type=low_res_img_type,
-                                   hr_img_type=high_res_img_type,
-                                   scaling_factor=scaling_factor)
+        transform = ImageTransform(
+            split=split,
+            crop_size=crop_size,
+            lr_img_type=low_res_img_type,
+            hr_img_type=high_res_img_type,
+            scaling_factor=scaling_factor,
+        )

         def generator():
             """Data generator for the TensorFlow Dataset."""

             for image_path in images:
-                img = Image.open(image_path, mode='r')
-                img = img.convert('RGB')
+                img = Image.open(image_path, mode="r")
+                img = img.convert("RGB")
                 # Transform
                 lr_img, hr_img = transform(img)
                 # Generate
                 yield lr_img, hr_img

-        return tf.data.Dataset.from_generator(generator=generator, output_signature=(
-            tf.TensorSpec(shape=(None, None, 3), dtype=tf.float32),
-            tf.TensorSpec(shape=(None, None, 3), dtype=tf.float32)))
+        return tf.data.Dataset.from_generator(
+            generator=generator,
+            output_signature=(
+                tf.TensorSpec(shape=(None, None, 3), dtype=tf.float32),
+                tf.TensorSpec(shape=(None, None, 3), dtype=tf.float32),
+            ),
+        )
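Because save_checkpoint exports a full SavedModel with an explicit serving signature, the artifact can be reloaded and queried without any of the classes above. A minimal sketch, assuming a hypothetical export directory Generator_99/ produced at epoch 99:

import tensorflow as tf

loaded = tf.saved_model.load("Generator_99/")  # hypothetical export dir
infer = loaded.signatures["serving_default"]  # the single signature registered at save time
lr_batch = tf.random.uniform((1, 24, 24, 3), dtype=tf.float32)
outputs = infer(image_batch=lr_batch)  # keyword must match serving_fn's argument name
sr_batch = outputs["output_0"]  # default key for an unnamed return tensor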
