forked from nimingniming/gdn
-
Notifications
You must be signed in to change notification settings - Fork 0
/
diffusionfeatures.py
43 lines (35 loc) · 1.3 KB
/
diffusionfeatures.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
import tensorflow as tf
class DiffuseFeatures(tf.keras.layers.Layer):
    """Utility layer calculating a single channel of the diffusional convolution.

    The layer owns one trainable vector ``kernel`` of length ``K``
    (``num_diffusion_steps``). In :meth:`call`, the kernel entries are used as
    polynomial coefficients evaluated at the adjacency input ``A``. The
    resulting matrix is multiplied with the node features ``X`` and the
    feature columns are summed, producing one value per node.

    Args:
        num_diffusion_steps: Number of kernel coefficients ``K``.
        kernel_initializer: Initializer handed to ``add_weight`` for the kernel.
        kernel_regularizer: Regularizer handed to ``add_weight`` for the kernel.
        kernel_constraint: Constraint handed to ``add_weight`` for the kernel.
        **kwargs: Standard Keras ``Layer`` keyword arguments (e.g. ``name``,
            ``dtype``, ``trainable``), forwarded to the base class.
    """

    def __init__(
        self,
        num_diffusion_steps: int,
        kernel_initializer,
        kernel_regularizer,
        kernel_constraint,
        **kwargs
    ):
        # Forward **kwargs so base-layer options such as `name` or `dtype`
        # are honoured instead of being silently discarded.
        super().__init__(**kwargs)
        # Number of diffusion steps (K in the paper); also the kernel length.
        self.K = num_diffusion_steps
        # Stored as given; they are passed straight to `add_weight` in `build`.
        self.kernel_initializer = kernel_initializer
        self.kernel_regularizer = kernel_regularizer
        self.kernel_constraint = kernel_constraint

    def build(self, input_shape):
        """Create the length-``K`` trainable kernel (``input_shape`` is unused)."""
        self.kernel = self.add_weight(
            shape=(self.K,),
            name="kernel",
            initializer=self.kernel_initializer,
            regularizer=self.kernel_regularizer,
            constraint=self.kernel_constraint,
        )

    def call(self, inputs):
        """Diffuse the node features and collapse them into one channel.

        Args:
            inputs: Pair ``(X, A)`` of node features and adjacency.
                NOTE(review): presumably ``X`` is ``([batch,] n_nodes,
                n_features)`` and ``A`` is ``([batch,] n_nodes, n_nodes)``;
                confirm against callers.

        Returns:
            The sum over the last axis of ``polyval(kernel, A) @ X``, with a
            trailing axis of size 1 appended.
        """
        # Get signal X and adjacency A.
        X, A = inputs
        # tf.math.polyval needs the coefficients as a Python list of tensors,
        # hence the unstack. It computes
        #   kernel[0] * A**(K-1) + ... + kernel[K-2] * A + kernel[K-1]
        # with element-wise powers of A, and the scalar last coefficient is
        # added to every entry (not to the identity).
        # NOTE(review): diffusion convolution is usually written with matrix
        # powers of the transition matrix; confirm element-wise is intended.
        diffusion_matrix = tf.math.polyval(tf.unstack(self.kernel), A)
        # Apply the diffusion matrix to the features.
        diffused_features = tf.matmul(diffusion_matrix, X)
        # Sum the diffused feature columns to get a single channel per node.
        H = tf.math.reduce_sum(diffused_features, axis=-1)
        # Restore a trailing channel axis of size 1.
        return tf.expand_dims(H, -1)