Skip to content

Commit

Permalink
small refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
AndreiMoraru123 committed Jul 27, 2023
1 parent c4b7e5a commit 899f35a
Showing 1 changed file with 12 additions and 14 deletions.
26 changes: 12 additions & 14 deletions architecture.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@

# third-party imports
import tensorflow as tf
from tensorflow.keras import losses, Model # type: ignore
from tensorflow.keras import Model # type: ignore
from tensorflow.keras.losses import Loss # type: ignore
from tensorflow.keras.optimizers import Optimizer # type: ignore
from PIL import Image # type: ignore

Expand All @@ -20,10 +21,10 @@ class Architecture(ABC):
def __init__(
self,
model: Model,
loss_fn: losses.Loss,
loss_fn: Loss,
optimizer: Optimizer,
model2: Optional[Model] = None,
loss_fn2: Optional[losses.Loss] = None,
loss_fn2: Optional[Loss] = None,
optimizer2: Optional[Optimizer] = None,
):
"""
Expand Down Expand Up @@ -58,7 +59,7 @@ class ResNetArchitecture(Architecture):
"""Super Resolution ResNet."""

@tf.function(jit_compile=True)
def train_step(self, low_res_images: Image.Image, high_res_images: Image.Image) -> losses.Loss:
def train_step(self, low_res_images: Image.Image, high_res_images: Image.Image) -> Loss:
with tf.GradientTape() as tape:
super_res_images = self.model(low_res_images, training=True)
loss = self.loss_fn(high_res_images, super_res_images)
Expand All @@ -78,8 +79,8 @@ def __init__(
dis_model: Model,
gen_optimizer: Optimizer,
dis_optimizer: Optimizer,
content_loss: losses.Loss,
adversarial_loss: losses.Loss,
content_loss: Loss,
adversarial_loss: Loss,
transform: ImageTransform,
vgg: Model,
beta: float = 1e-3,
Expand All @@ -103,12 +104,7 @@ def __init__(
self.beta = beta

@tf.function(jit_compile=True)
def train_step(
self,
low_res_images: Image.Image,
high_res_images: Image.Image,
) -> Tuple[losses.Loss, losses.Loss]:

def train_step(self, low_res_images: Image.Image, high_res_images: Image.Image) -> Tuple[Loss, Loss]:
with tf.GradientTape() as gen_tape:
super_res_images = self.model(low_res_images)
super_res_images = self.transform.convert_image(super_res_images,
Expand All @@ -129,8 +125,10 @@ def train_step(
with tf.GradientTape() as dis_tape:
super_res_discriminated = self.model2(tf.stop_gradient(super_res_images)) # re-evaluate without gradients
high_res_discriminated = self.model2(high_res_images)
adversarial_loss = self.loss_fn2(super_res_discriminated, tf.zeros_like(super_res_discriminated)) + \
self.loss_fn2(high_res_discriminated, tf.ones_like(high_res_discriminated))

adversarial_loss = \
self.loss_fn2(super_res_discriminated, tf.zeros_like(super_res_discriminated)) + \
self.loss_fn2(high_res_discriminated, tf.ones_like(high_res_discriminated))

dis_gradients = dis_tape.gradient(adversarial_loss, self.model2.trainable_variables)
self.optimizer2.apply_gradients(zip(dis_gradients, self.model2.trainable_variables))
Expand Down

0 comments on commit 899f35a

Please sign in to comment.