-
Notifications
You must be signed in to change notification settings - Fork 6
/
fuse_mechanism.py
43 lines (30 loc) · 1.76 KB
/
fuse_mechanism.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
import tensorflow as tf
import tensorflow.contrib.layers as tfc_layers
def fuse_by_concat(input1, input2):
    """Fuse two tensors by concatenating them along the last axis.

    Args:
        input1: first tensor; comes first in the concatenated result.
        input2: second tensor; appended after ``input1``.

    Returns:
        The output of ``tf.concat`` over ``[input1, input2]`` with ``axis=-1``.
    """
    with tf.variable_scope('concat'):
        fused = tf.concat([input1, input2], axis=-1)
    return fused
def fuse_by_dot_product(input1, input2):
    """Fuse two tensors by multiplying them with the ``*`` operator.

    Note that, despite the name, no reduction is performed: the result is
    ``input1 * input2``, not a summed dot product.

    Args:
        input1: first tensor.
        input2: second tensor; its static dimension at index 1 must equal
            that of ``input1``.

    Returns:
        ``input1 * input2``.

    Raises:
        AssertionError: if the static dimensions at index 1 do not compare
            equal (only when Python assertions are enabled).
    """
    with tf.variable_scope('dot'):
        # Only axis 1 is compared; other axes are left to the multiply op.
        assert input1.get_shape()[1] == input2.get_shape()[1]
        fused = input1 * input2
    return fused
def fuse_by_brut_force(input1, input2):
    """Fuse two tensors by concatenating several pairwise combinations.

    The result concatenates, along the last axis and in this order:
    ``input1``, ``input2``, ``|input1 - input2|`` and ``input1 * input2``.

    Args:
        input1: first tensor.
        input2: second tensor, combined element-wise with ``input1``.

    Returns:
        The concatenated tensor produced by ``tf.concat(..., axis=-1)``.

    Raises:
        AssertionError: if both static ranks are known and they are not both
            equal to 2 (only when Python assertions are enabled).
    """
    with tf.variable_scope('brut_force_fusion'):
        # The previous check compared `tf.size(...)` tensors with Python `==`,
        # which does not compare runtime values at graph-construction time, so
        # it could not validate the inputs. Use the static rank instead.
        # NOTE(review): the original `== 2` is read here as "both inputs are
        # rank 2" - confirm this is the intended constraint.
        rank1 = input1.get_shape().ndims
        rank2 = input2.get_shape().ndims
        if rank1 is not None and rank2 is not None:
            assert rank1 == rank2 and rank1 == 2

        # The fused tensor was previously built but never returned.
        return tf.concat(
            [input1, input2, tf.abs(input1 - input2), input1 * input2],
            axis=-1)
def fuse_by_vis(input1, input2, projection_size, dropout_keep, activation=tf.nn.tanh, reuse=False):
    """Fuse two tensors by projecting each one and multiplying the projections.

    Each input goes through ``tf.nn.dropout`` and then its own
    ``fully_connected`` layer; the two projections are multiplied with ``*``
    and the product goes through ``tf.nn.dropout`` once more.

    Args:
        input1: first tensor; projected under the scope ``input1_projection``.
        input2: second tensor; projected under the scope ``input2_projection``.
        projection_size: ``num_outputs`` of both fully connected layers.
        dropout_keep: passed as the second positional argument to every
            ``tf.nn.dropout`` call (the keep probability in TF 1.x).
        activation: ``activation_fn`` of both fully connected layers.
        reuse: forwarded as ``reuse`` to both fully connected layers.

    Returns:
        The product of the two projections after the final dropout.
    """

    def _dropout_then_project(tensor, scope_name):
        # Dropout is applied to the raw input before the dense projection.
        dropped = tf.nn.dropout(tensor, dropout_keep)
        return tfc_layers.fully_connected(
            dropped,
            num_outputs=projection_size,
            activation_fn=activation,
            reuse=reuse,
            scope=scope_name)

    with tf.variable_scope('vis_fusion'):
        # Keep the input1-then-input2 order so ops are created in the same
        # sequence as before.
        projected1 = _dropout_then_project(input1, "input1_projection")
        projected2 = _dropout_then_project(input2, "input2_projection")

        fused = projected1 * projected2
        return tf.nn.dropout(fused, dropout_keep)