Skip to content

Commit

Permalink
black formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
AndreiMoraru123 committed Jul 30, 2023
1 parent 24c8be6 commit 3127862
Show file tree
Hide file tree
Showing 11 changed files with 384 additions and 171 deletions.
66 changes: 47 additions & 19 deletions architecture.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,9 @@ class ResNetArchitecture(Architecture):
"""Super Resolution ResNet."""

@tf.function(jit_compile=True)
def train_step(self, low_res_images: Image.Image, high_res_images: Image.Image) -> Loss:
def train_step(
self, low_res_images: Image.Image, high_res_images: Image.Image
) -> Loss:
with tf.GradientTape() as tape:
super_res_images = self.model(low_res_images, training=True)
loss = self.loss_fn(high_res_images, super_res_images)
Expand Down Expand Up @@ -97,40 +99,66 @@ def __init__(
:param vgg: Optional truncated VGG19 to project the predictions into a dimension where the loss makes more sense
:param beta: the coefficient to weight the adversarial loss in the perceptual loss
"""
super().__init__(model=gen_model, loss_fn=content_loss, optimizer=gen_optimizer,
model2=dis_model, loss_fn2=adversarial_loss, optimizer2=dis_optimizer)
super().__init__(
model=gen_model,
loss_fn=content_loss,
optimizer=gen_optimizer,
model2=dis_model,
loss_fn2=adversarial_loss,
optimizer2=dis_optimizer,
)
self.transform = transform
self.vgg = vgg
self.beta = beta

@tf.function(jit_compile=True)
def train_step(self, low_res_images: Image.Image, high_res_images: Image.Image) -> Tuple[Loss, Loss]:
def train_step(
self, low_res_images: Image.Image, high_res_images: Image.Image
) -> Tuple[Loss, Loss]:
with tf.GradientTape() as gen_tape:
super_res_images = self.model(low_res_images, training=True)
super_res_images = self.transform.convert_image(super_res_images,
source='[-1, 1]',
target='imagenet-norm')
super_res_images = self.transform.convert_image(
super_res_images, source="[-1, 1]", target="imagenet-norm"
)
super_res_images_vgg_space = self.vgg(super_res_images)
high_res_images_vgg_space = self.vgg(tf.stop_gradient(high_res_images)) # does not get updated
high_res_images_vgg_space = self.vgg(
tf.stop_gradient(high_res_images)
) # does not get updated

super_res_discriminated = self.model2(super_res_images, training=True)

content_loss = self.loss_fn(super_res_images_vgg_space, high_res_images_vgg_space)
adversarial_loss = self.loss_fn2(super_res_discriminated, tf.ones_like(super_res_discriminated))
content_loss = self.loss_fn(
super_res_images_vgg_space, high_res_images_vgg_space
)
adversarial_loss = self.loss_fn2(
super_res_discriminated, tf.ones_like(super_res_discriminated)
)
perceptual_loss = content_loss + self.beta * adversarial_loss

gen_gradients = gen_tape.gradient(perceptual_loss, self.model.trainable_variables)
self.optimizer.apply_gradients(zip(gen_gradients, self.model.trainable_variables))
gen_gradients = gen_tape.gradient(
perceptual_loss, self.model.trainable_variables
)
self.optimizer.apply_gradients(
zip(gen_gradients, self.model.trainable_variables)
)

with tf.GradientTape() as dis_tape:
super_res_discriminated = self.model2(tf.stop_gradient(super_res_images), training=True)
super_res_discriminated = self.model2(
tf.stop_gradient(super_res_images), training=True
)
high_res_discriminated = self.model2(high_res_images, training=True)

adversarial_loss = \
self.loss_fn2(super_res_discriminated, tf.zeros_like(super_res_discriminated)) + \
self.loss_fn2(high_res_discriminated, tf.ones_like(high_res_discriminated))

dis_gradients = dis_tape.gradient(adversarial_loss, self.model2.trainable_variables)
self.optimizer2.apply_gradients(zip(dis_gradients, self.model2.trainable_variables))
adversarial_loss = self.loss_fn2(
super_res_discriminated, tf.zeros_like(super_res_discriminated)
) + self.loss_fn2(
high_res_discriminated, tf.ones_like(high_res_discriminated)
)

dis_gradients = dis_tape.gradient(
adversarial_loss, self.model2.trainable_variables
)
self.optimizer2.apply_gradients(
zip(dis_gradients, self.model2.trainable_variables)
)

return perceptual_loss, adversarial_loss
41 changes: 26 additions & 15 deletions create_data_lists.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@
from PIL import Image # type: ignore


def create_data_lists(train_folders: List[str], test_folders: List[str], min_size: int, output_folder: str):
def create_data_lists(
train_folders: List[str], test_folders: List[str], min_size: int, output_folder: str
):
"""
Create lists for images in the training set and each of the test sets.
Expand All @@ -21,33 +23,42 @@ def create_data_lists(train_folders: List[str], test_folders: List[str], min_siz
for d in train_folders:
for i in os.listdir(d):
img_path = os.path.join(d, i)
img = Image.open(img_path, mode='r')
img = Image.open(img_path, mode="r")
if img.width >= min_size and img.height >= min_size:
train_images.append(img_path)
print("There are %d images in the training data.\n" % len(train_images))
with open(os.path.join(output_folder, 'train_images.json'), 'w') as j:
with open(os.path.join(output_folder, "train_images.json"), "w") as j:
json.dump(train_images, j)

for d in test_folders:
test_images = list()
test_name = d.split("/")[-1]
for i in os.listdir(d):
img_path = os.path.join(d, i)
img = Image.open(img_path, mode='r')
img = Image.open(img_path, mode="r")
if img.width >= min_size and img.height >= min_size:
test_images.append(img_path)
print("There are %d images in the %s test data.\n" % (len(test_images), test_name))
with open(os.path.join(output_folder, test_name + '_test_images.json'), 'w') as j:
print(
"There are %d images in the %s test data.\n" % (len(test_images), test_name)
)
with open(
os.path.join(output_folder, test_name + "_test_images.json"), "w"
) as j:
json.dump(test_images, j)

print("JSONS containing lists of Train and Test images have been saved to %s\n" % output_folder)
print(
"JSONS containing lists of Train and Test images have been saved to %s\n"
% output_folder
)


if __name__ == '__main__':
create_data_lists(train_folders=[r'D:\WatchAndTellCuda\coco\images\train2017',
r'D:\WatchAndTellCuda\coco\images\val2017'],
test_folders=['SR/BSDS100',
'SR/Set5',
'SR/Set14'],
min_size=100,
output_folder='./')
if __name__ == "__main__":
create_data_lists(
train_folders=[
r"D:\WatchAndTellCuda\coco\images\train2017",
r"D:\WatchAndTellCuda\coco\images\val2017",
],
test_folders=["SR/BSDS100", "SR/Set5", "SR/Set14"],
min_size=100,
output_folder="./",
)
14 changes: 10 additions & 4 deletions evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def __init__(
scaling_factor: int = 4,
low_res_image_type: str = "imagenet-norm",
high_res_image_type: str = "[-1, 1]",
test_data_name: str = "dummy"
test_data_name: str = "dummy",
):
"""
:param resnet: the SRResNet TF model to be evaluated
Expand Down Expand Up @@ -67,8 +67,12 @@ def evaluate(self):
"""evaluates the model using peak signal-to-noise ratio and structural similarity."""

for _, (low_res_images, high_res_images) in enumerate(self.dataset):
super_res_images_resnet = self.resnet_inference(tf.expand_dims(low_res_images, axis=0))['output_0']
super_res_images_srgan = self.generator_inference(tf.expand_dims(low_res_images, axis=0))['output_0']
super_res_images_resnet = self.resnet_inference(
tf.expand_dims(low_res_images, axis=0)
)["output_0"]
super_res_images_srgan = self.generator_inference(
tf.expand_dims(low_res_images, axis=0)
)["output_0"]

super_res_images_resnet_y = self.transform.convert_image(
super_res_images_resnet, source="[-1, 1]", target="y-channel"
Expand All @@ -81,7 +85,9 @@ def evaluate(self):
super_res_images_srgan_y = tf.squeeze(super_res_images_srgan_y, axis=0)

high_res_images_y = self.transform.convert_image(
tf.expand_dims(high_res_images, axis=0), source="[-1, 1]", target="y-channel"
tf.expand_dims(high_res_images, axis=0),
source="[-1, 1]",
target="y-channel",
)
high_res_images_y = tf.squeeze(high_res_images_y, axis=0)

Expand Down
Loading

0 comments on commit 3127862

Please sign in to comment.