-
Notifications
You must be signed in to change notification settings - Fork 9
/
losses.py
62 lines (54 loc) · 2.53 KB
/
losses.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import tensorflow as tf
import numpy as np
from tensorflow.keras import losses
# TODO: add a perceptual loss term.
def disc_hinge(dis_real, dis_fake):
    """Return the discriminator hinge loss.

    Real scores are penalised only where they fall below +1, and fake
    scores only where they rise above -1. Each penalty is averaged over
    all elements and the two averages are then averaged together.

    Args:
        dis_real: Discriminator outputs for real samples.
        dis_fake: Discriminator outputs for generated samples.

    Returns:
        Scalar tensor equal to half the sum of the two mean penalties.
    """
    # min(0, s - 1) is non-zero only where a real score s is below the margin.
    real_shortfall = tf.minimum(0.0, dis_real - 1.0)
    # min(0, -1 - s) is non-zero only where a fake score s is above -1.
    fake_shortfall = tf.minimum(0.0, -1.0 - dis_fake)

    # The shortfalls are <= 0, so negate the means to obtain positive losses.
    penalty_real = -1.0 * tf.reduce_mean(real_shortfall)
    penalty_fake = -1.0 * tf.reduce_mean(fake_shortfall)
    return (penalty_real + penalty_fake) / 2.0
def gen_hinge(dis_fake):
    """Return the generator hinge loss: the negated mean of the fake scores.

    Args:
        dis_fake: Discriminator outputs for generated samples.

    Returns:
        Scalar tensor; lower values correspond to higher fake scores.
    """
    mean_fake_score = tf.reduce_mean(dis_fake)
    return -1.0 * mean_fake_score
def disc_loss(dis_real, dis_fake, dis_wrong=None):
    """Return the discriminator binary cross-entropy loss with noisy soft labels.

    Instead of hard 1/0 targets, every element gets an independently drawn
    soft label: real targets come from {0.7, 0.8, 0.9, 1.0, 1.1} and fake
    targets from {0.0, 0.1, 0.2}.

    The labels are sampled with TensorFlow ops rather than NumPy so that:
      * fresh labels are drawn on every call even inside a ``tf.function``
        (a NumPy draw would be evaluated once at trace time and frozen);
      * inputs whose static shape has unknown dimensions are supported;
      * the fake labels are sized from ``dis_fake`` itself, not ``dis_real``.

    Args:
        dis_real: Discriminator outputs for real samples. They are passed to
            ``BinaryCrossentropy`` with its default settings, i.e. they are
            treated as probabilities, not logits.
        dis_fake: Discriminator outputs for generated samples (same
            interpretation as ``dis_real``).
        dis_wrong: Unused. Kept so existing call sites that pass a third
            score tensor keep working.

    Returns:
        Tensor holding the mean of the real and fake cross-entropy terms.
        The loss object uses ``Reduction.NONE``, so no reduction over the
        batch is applied here; the caller is responsible for it.
    """
    # Draw integer "tenths" and scale them, which reproduces the discrete
    # label grid exactly (maxval is exclusive for integer sampling).
    real_tenths = tf.random.uniform(
        tf.shape(dis_real), minval=7, maxval=12, dtype=tf.int32
    )
    fake_tenths = tf.random.uniform(
        tf.shape(dis_fake), minval=0, maxval=3, dtype=tf.int32
    )
    real = tf.cast(real_tenths, tf.float32) / 10.0
    fake = tf.cast(fake_tenths, tf.float32) / 10.0

    bce = losses.BinaryCrossentropy(reduction=tf.keras.losses.Reduction.NONE)
    real_loss = bce(real, dis_real)
    fake_loss = bce(fake, dis_fake)
    return (real_loss + fake_loss) / 2.0
def gen_loss(dis_fake):
    """Return the generator binary cross-entropy loss against all-ones targets.

    Args:
        dis_fake: Discriminator outputs for generated samples, passed to
            ``BinaryCrossentropy`` with its default settings.

    Returns:
        Tensor of cross-entropy values; the loss object uses
        ``Reduction.NONE``, so no reduction over the batch is applied here.
    """
    bce = losses.BinaryCrossentropy(reduction=tf.keras.losses.Reduction.NONE)
    # The generator wants its samples to be classified as real (label 1).
    target = tf.ones_like(dis_fake)
    return bce(target, dis_fake)
def critic_loss(D_real, D_fake):
    """Return the mean of ``D_fake`` minus the mean of ``D_real``.

    Args:
        D_real: Critic outputs for real samples.
        D_fake: Critic outputs for generated samples.

    Returns:
        Scalar tensor; it decreases as real samples score higher than
        generated ones.
    """
    mean_fake = tf.reduce_mean(D_fake)
    mean_real = tf.reduce_mean(D_real)
    return mean_fake - mean_real
# def gen_loss(D_fake, real_img, fake_img):
# # fake = tf.ones_like(D_fake)
# # return tf.reduce_mean(D_fake*fake)
# # real_img = tf.clip_by_value(255.0*(real_img*0.5+0.5), 0.0, 255.0)
# # fake_img = tf.clip_by_value(255.0*(fake_img*0.5+0.5), 0.0, 255.0)
# return -tf.reduce_mean(D_fake) #+ 0.6 * tf.keras.losses.MeanSquaredError()(real_img, fake_img)
def wgan_gp_loss(D_real, D_fake, Y, Y_cap, model, batch_size, lam=10.0):
    """Return the critic loss with a gradient penalty (WGAN-GP style).

    The loss is ``mean(D_fake) - mean(D_real)`` plus ``lam`` times the mean
    squared deviation of the critic's input-gradient norm from 1, evaluated
    at random interpolations between ``Y`` and ``Y_cap``.

    Args:
        D_real: Critic outputs for real samples.
        D_fake: Critic outputs for generated samples.
        Y: Real samples. The interpolation weights have shape
            ``[batch_size, 1, 1, 1]`` and the gradient norm is reduced over
            axes 1-3, so a 4-D tensor whose leading dimension is
            ``batch_size`` is expected.
        Y_cap: Generated samples, broadcast-compatible with ``Y``.
        model: Object exposing ``critic(x, training=True)``.
        batch_size: Leading dimension used for the interpolation weights.
        lam: Weight of the gradient-penalty term. Defaults to 10, the value
            that was previously hard-coded.

    Returns:
        Scalar tensor: Wasserstein term plus the weighted gradient penalty.
    """
    wasserstein = tf.reduce_mean(D_fake) - tf.reduce_mean(D_real)

    # One interpolation coefficient per sample, broadcast over axes 1-3.
    eps = tf.random.uniform(shape=[batch_size, 1, 1, 1], minval=0.0, maxval=1.0)
    x_cap = eps * Y + (1.0 - eps) * Y_cap

    with tf.GradientTape() as gptape:
        # x_cap is not a variable, so it must be watched explicitly.
        gptape.watch(x_cap)
        out = model.critic(x_cap, training=True)
    grad = gptape.gradient(out, [x_cap])[0]

    # Per-sample L2 norm over all non-batch axes. The small constant keeps
    # the derivative of sqrt finite when a sample's gradient is exactly zero;
    # without it the penalty can back-propagate NaNs.
    sq_norm = tf.reduce_sum(tf.square(grad), axis=[1, 2, 3])
    grad_norm = tf.sqrt(sq_norm + 1e-12)
    grad_pen = tf.reduce_mean(tf.square(grad_norm - 1.0))

    return wasserstein + lam * grad_pen