cbn_pluggin.py
import tensorflow as tf
import tensorflow.contrib.layers as tfc_layers
class CBNAbtract(object):
    """
    Factory (design pattern) used to create conditional batch norm (CBN) inputs.
    """

    def create_cbn_input(self, feature_maps):
        """
        This method is called every time conditional batch norm is applied.
        The factory makes it possible to inject CBN into a pretrained ResNet.

        Good practice is to pass the CBN input (an LSTM embedding, for instance)
        to the constructor and then use that variable in create_cbn_input, e.g.

            def __init__(self, lstm_state):
                self.lstm_state = lstm_state

            def create_cbn_input(self, feature_maps):
                feat = int(feature_maps.get_shape()[3])
                delta_betas = tf.contrib.layers.fully_connected(self.lstm_state, num_outputs=feat)
                delta_gammas = tf.contrib.layers.fully_connected(self.lstm_state, num_outputs=feat)
                return delta_betas, delta_gammas

        :param feature_maps: (None, h, w, f)
        :return: delta_betas, delta_gammas: (None, f), (None, f)
        """

        # Default implementation: no modulation, i.e. all-zero deltas.
        # Use the dynamic batch size so the op builds even when the static
        # batch dimension is None.
        batch_size = tf.shape(feature_maps)[0]
        no_features = int(feature_maps.get_shape()[3])

        delta_betas = tf.zeros(shape=tf.stack([batch_size, no_features]))
        delta_gammas = tf.zeros(shape=tf.stack([batch_size, no_features]))

        return delta_betas, delta_gammas
class CBNfromLSTM(CBNAbtract):
    """
    Compute CBN deltas from an LSTM state through a small MLP head.
    """

    def __init__(self, lstm_state, no_units, dropout_keep=1.0, use_betas=True, use_gammas=True):
        self.lstm_state = lstm_state
        self.cbn_embedding_size = no_units
        self.use_betas = use_betas
        self.use_gammas = use_gammas
        self.dropout_keep = dropout_keep

    def create_cbn_input(self, feature_maps):

        no_features = int(feature_maps.get_shape()[3])
        batch_size = tf.shape(feature_maps)[0]

        if self.use_betas:
            # One hidden layer (ReLU + dropout), then project to one delta per feature map
            h_betas = tfc_layers.fully_connected(self.lstm_state,
                                                 num_outputs=self.cbn_embedding_size,
                                                 activation_fn=tf.nn.relu,
                                                 scope="hidden_betas")
            h_betas = tf.nn.dropout(h_betas, self.dropout_keep)
            delta_betas = tfc_layers.fully_connected(h_betas,
                                                     num_outputs=no_features,
                                                     activation_fn=None,
                                                     biases_initializer=None,
                                                     scope="delta_beta")
        else:
            # Betas disabled: broadcast a zero delta over the batch
            delta_betas = tf.tile(tf.constant(0.0, shape=[1, no_features]), tf.stack([batch_size, 1]))

        if self.use_gammas:
            h_gammas = tfc_layers.fully_connected(self.lstm_state,
                                                  num_outputs=self.cbn_embedding_size,
                                                  activation_fn=tf.nn.relu,
                                                  scope="hidden_gammas")
            h_gammas = tf.nn.dropout(h_gammas, self.dropout_keep)
            delta_gammas = tfc_layers.fully_connected(h_gammas,
                                                      num_outputs=no_features,
                                                      activation_fn=None,
                                                      biases_initializer=None,
                                                      scope="delta_gamma")
        else:
            # Gammas disabled: broadcast a zero delta over the batch
            delta_gammas = tf.tile(tf.constant(0.0, shape=[1, no_features]), tf.stack([batch_size, 1]))

        return delta_betas, delta_gammas
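

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module): how the factory is
# meant to be plugged into a TF1.x graph. The 1024-d LSTM state and the
# 7x7x2048 feature maps below are hypothetical placeholders for illustration.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    lstm_state = tf.placeholder(tf.float32, shape=[None, 1024], name="lstm_state")
    feature_maps = tf.placeholder(tf.float32, shape=[None, 7, 7, 2048], name="feature_maps")

    # Build the factory once, then call create_cbn_input wherever conditional
    # batch norm is applied (e.g. inside each ResNet block being modulated).
    cbn_factory = CBNfromLSTM(lstm_state, no_units=512, dropout_keep=0.8)

    with tf.variable_scope("cbn_block"):
        delta_betas, delta_gammas = cbn_factory.create_cbn_input(feature_maps)

    # Both deltas have shape (None, 2048) and are intended to be added to the
    # batch-norm betas/gammas of the corresponding feature maps.
    print(delta_betas.get_shape(), delta_gammas.get_shape())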