-
Notifications
You must be signed in to change notification settings - Fork 30
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
14 changed files
with
1,214 additions
and
32 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
from keras_core.layers import Layer | ||
from keras_core import ops | ||
|
||
class Aggregate(Layer): | ||
|
||
def __init__(self, axis=0, **kwargs): | ||
super(Aggregate, self).__init__(**kwargs) | ||
self.axis = axis | ||
if axis != 0: | ||
raise NotImplementedError | ||
def call(self, inputs, **kwargs): | ||
x, index, dim_size = inputs | ||
# For test only sum scatter, no segment operation no other poolings etc. | ||
# will add all poolings here. | ||
return ops.scatter(index, x, shape=ops.concatenate([[dim_size], ops.shape[1:]])) | ||
|
||
|
||
class AggregateLocalEdges(Layer): | ||
def __init__(self, **kwargs): | ||
super(AggregateLocalEdges, self).__init__(**kwargs) | ||
self.to_aggregate = Aggregate() | ||
|
||
def call(self, inputs, **kwargs): | ||
n, edges, edge_index = inputs | ||
# For test only sum scatter, no segment operation etc. | ||
return self.to_aggregate(edges, edge_index[1], ops.shape(n)[0]) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,103 @@ | ||
from keras_core.layers import Layer, Dense, Activation | ||
from keras_core import ops | ||
from kgcnn.layers.aggr import AggregateWeightedLocalEdges | ||
from kgcnn.layers.gather import GatherNodesOutgoing | ||
|
||
|
||
class GCN(Layer): | ||
r"""Graph convolution according to `Kipf et al <https://arxiv.org/abs/1609.02907>`__ . | ||
Computes graph convolution as :math:`\sigma(A_s(WX+b))` where :math:`A_s` is the precomputed and scaled adjacency | ||
matrix. The scaled adjacency matrix is defined by :math:`A_s = D^{-0.5} (A + I) D^{-0.5}` with the degree | ||
matrix :math:`D` . In place of :math:`A_s` , this layers uses edge features (that are the entries of :math:`A_s` ) | ||
and edge indices. | ||
.. note:: | ||
:math:`A_s` is considered pre-scaled, this is not done by this layer! | ||
If no scaled edge features are available, you could consider use e.g. "mean", | ||
or :obj:`normalize_by_weights` to obtain a similar behaviour that is expected b | ||
y a pre-scaled adjacency matrix input. | ||
Edge features must be possible to broadcast to node features, since they are multiplied with the node features. | ||
Ideally they are weights of shape `(..., 1)` for broadcasting, e.g. entries of :math:`A_s` . | ||
Args: | ||
units (int): Output dimension/ units of dense layer. | ||
pooling_method (str): Pooling method for summing edges. Default is 'segment_sum'. | ||
normalize_by_weights (bool): Normalize the pooled output by the sum of weights. Default is False. | ||
In this case the edge features are considered weights of dimension (...,1) and are summed for each node. | ||
activation (str): Activation. Default is 'kgcnn>leaky_relu'. | ||
use_bias (bool): Use bias. Default is True. | ||
kernel_regularizer: Kernel regularization. Default is None. | ||
bias_regularizer: Bias regularization. Default is None. | ||
activity_regularizer: Activity regularization. Default is None. | ||
kernel_constraint: Kernel constrains. Default is None. | ||
bias_constraint: Bias constrains. Default is None. | ||
kernel_initializer: Initializer for kernels. Default is 'glorot_uniform'. | ||
bias_initializer: Initializer for bias. Default is 'zeros'. | ||
""" | ||
|
||
def __init__(self, | ||
units, | ||
pooling_method='sum', | ||
normalize_by_weights=False, | ||
activation='kgcnn>leaky_relu', | ||
use_bias=True, | ||
kernel_regularizer=None, | ||
bias_regularizer=None, | ||
activity_regularizer=None, | ||
kernel_constraint=None, | ||
bias_constraint=None, | ||
kernel_initializer='glorot_uniform', | ||
bias_initializer='zeros', | ||
**kwargs): | ||
"""Initialize layer.""" | ||
super(GCN, self).__init__(**kwargs) | ||
self.normalize_by_weights = normalize_by_weights | ||
self.pooling_method = pooling_method | ||
self.units = units | ||
kernel_args = {"kernel_regularizer": kernel_regularizer, "activity_regularizer": activity_regularizer, | ||
"bias_regularizer": bias_regularizer, "kernel_constraint": kernel_constraint, | ||
"bias_constraint": bias_constraint, "kernel_initializer": kernel_initializer, | ||
"bias_initializer": bias_initializer, "use_bias": use_bias} | ||
pool_args = {"pooling_method": pooling_method, "normalize_by_weights": normalize_by_weights} | ||
|
||
# Layers | ||
self.lay_gather = GatherNodesOutgoing() | ||
self.lay_dense = Dense(units=self.units, activation='linear', **kernel_args) | ||
self.lay_pool = AggregateWeightedLocalEdges(**pool_args) | ||
self.lay_act = Activation(activation) | ||
|
||
def call(self, inputs, **kwargs): | ||
"""Forward pass. | ||
Args: | ||
inputs: [nodes, edges, edge_index] | ||
- nodes (tf.RaggedTensor): Node embeddings of shape (batch, [N], F) | ||
- edges (tf.RaggedTensor): Edge or message embeddings of shape (batch, [M], F) | ||
- edge_index (tf.RaggedTensor): Edge indices referring to nodes of shape (batch, [M], 2) | ||
Returns: | ||
tf.RaggedTensor: Node embeddings of shape (batch, [N], F) | ||
""" | ||
node, edges, edge_index = inputs | ||
no = self.lay_dense(node, **kwargs) | ||
no = self.lay_gather([no, edge_index], **kwargs) | ||
nu = self.lay_pool([node, no, edge_index, edges], **kwargs) # Summing for each node connection | ||
out = self.lay_act(nu, **kwargs) | ||
return out | ||
|
||
def get_config(self): | ||
"""Update config.""" | ||
config = super(GCN, self).get_config() | ||
config.update({"normalize_by_weights": self.normalize_by_weights, | ||
"pooling_method": self.pooling_method, "units": self.units}) | ||
conf_dense = self.lay_dense.get_config() | ||
for x in ["kernel_regularizer", "activity_regularizer", "bias_regularizer", "kernel_constraint", | ||
"bias_constraint", "kernel_initializer", "bias_initializer", "use_bias"]: | ||
config.update({x: conf_dense[x]}) | ||
conf_act = self.lay_act.get_config() | ||
config.update({"activation": conf_act["activation"]}) | ||
return config |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
from keras_core.layers import Layer | ||
from keras_core import ops | ||
|
||
|
||
class GatherNodes(Layer): | ||
|
||
def __init__(self, split_indices=(0, 1), concat_axis=1, **kwargs): | ||
super(GatherNodes, self).__init__(**kwargs) | ||
self.split_indices = split_indices | ||
self.concat_axis = concat_axis | ||
|
||
def call(self, inputs, **kwargs): | ||
x, index = inputs | ||
gathered = [x[i] for i in self.split_indices] | ||
if self.concat_axis is not None: | ||
gathered = ops.concatenate(gathered, axis=self.concat_axis) | ||
return gathered | ||
|
||
|
||
class GatherNodesOutgoing(GatherNodes): | ||
|
||
def __init__(self, **kwargs): | ||
super(GatherNodes, self).__init__(split_indices=1, concat_axis=None, **kwargs) | ||
|
||
def call(self, inputs, **kwargs): | ||
return super(GatherNodesOutgoing, self).call(inputs, **kwargs)[0] | ||
|
Oops, something went wrong.