Commit 7508c21
update keras core.
PatReis committed Aug 31, 2023
1 parent 4cd90c1 commit 7508c21
Showing 14 changed files with 1,214 additions and 32 deletions.
2 changes: 1 addition & 1 deletion SDP.md
@@ -8,7 +8,7 @@ Software Development Plan (SDP)
* [x] Add graph preprocessor from standard dictionary scheme also for ``crystal`` and `molecule` .
* [x] Rework and clean base layers.
* [ ] Add a properly designed transformer layer in ``kgcnn.layers`` .
* [ ] Add a loader for ``Graphlist`` apart from tensor files. Must change dataformat for standard save.
* [ ] Add an element-wise loader for ``Graphlist`` apart from tensor files. Must change data format for standard save.
* [ ] Make a ``tf_dataset()`` function to return a generator dataset from `Graphlist` .
* [ ] Add ``JARVISDataset`` . There is already a (not yet complete) port for `kgcnn` .
* [ ] Add a package-wide logger level that can be changed.
1 change: 0 additions & 1 deletion kgcnn/data/force.py
@@ -28,7 +28,6 @@ class ForceDataset(QMDataset):
└── dataset_name.kgcnn.pickle
Additionally, force information in xyz format can be read in with this class.
"""

def __init__(self, data_directory: str = None, dataset_name: str = None, file_name: str = None,
26 changes: 26 additions & 0 deletions kgcnn/layers_core/aggr.py
@@ -0,0 +1,26 @@
from keras_core.layers import Layer
from keras_core import ops


class Aggregate(Layer):

    def __init__(self, axis=0, **kwargs):
        super(Aggregate, self).__init__(**kwargs)
        self.axis = axis
        if axis != 0:
            raise NotImplementedError("Aggregation along an axis other than 0 is not yet supported.")

    def call(self, inputs, **kwargs):
        x, index, dim_size = inputs
        # For test only sum scatter, no segment operation, no other poolings etc.
        # Will add all poolings here. Note that ops.scatter sums duplicate indices
        # and expects index tuples, hence index is expanded to shape (M, 1).
        return ops.scatter(ops.expand_dims(index, axis=-1), x,
                           shape=ops.concatenate([[dim_size], ops.shape(x)[1:]]))


class AggregateLocalEdges(Layer):

    def __init__(self, **kwargs):
        super(AggregateLocalEdges, self).__init__(**kwargs)
        self.to_aggregate = Aggregate()

    def call(self, inputs, **kwargs):
        n, edges, edge_index = inputs
        # For test only sum scatter, no segment operation etc.
        # Aggregate takes a single list input of [values, index, dim_size].
        return self.to_aggregate([edges, edge_index[1], ops.shape(n)[0]])
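
Usage sketch (illustrative only, not part of the commit): toy numpy inputs for AggregateLocalEdges, assuming the fixed code above and the module path kgcnn/layers_core/aggr.py added here; edge_index follows the (2, M) convention with receiving nodes in row 1.

import numpy as np
from kgcnn.layers_core.aggr import AggregateLocalEdges

# Toy disjoint graph: 3 nodes with 2 features, 4 directed edges with 1 feature.
nodes = np.zeros((3, 2), dtype="float32")
edges = np.array([[1.0], [2.0], [3.0], [4.0]], dtype="float32")
edge_index = np.array([[0, 0, 1, 2],   # source nodes
                       [1, 2, 2, 0]],  # receiving nodes, used for aggregation
                      dtype="int64")

out = AggregateLocalEdges()([nodes, edges, edge_index])
# Sum of incoming edge features per node: node 2 receives 2.0 + 3.0 = 5.0,
# so out is [[4.0], [1.0], [5.0]] with shape (3, 1).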
49 changes: 38 additions & 11 deletions kgcnn/layers_core/casting.py
@@ -11,9 +11,13 @@ class CastBatchGraphListToPyGDisjoint(Layer):
batched tensors are preferred.
"""

def __init__(self, reverse_indices: bool = True, **kwargs):
def __init__(self, reverse_indices: bool = True, batch_dtype: str = "int64",
batch_info: str = "lengths", **kwargs):
super(CastBatchGraphListToPyGDisjoint, self).__init__(**kwargs)
self.reverse_indices = reverse_indices
self.batch_dtype = batch_dtype
self.batch_info = batch_info
assert batch_info in ["lengths", "mask"], "Unknown format for batch information tensor expected in call()."

def build(self, input_shape):
return super(CastBatchGraphListToPyGDisjoint, self).build(input_shape)
@@ -46,16 +50,39 @@ def call(self, inputs: list, **kwargs):

# Case 1: Padded node and edges tensors but with batch dimension at axis 0.
if all_tensor and len(inputs) == 5:
nodes, edges, edge_indices, node_len, edge_len = inputs
node_mask = ops.repeat(ops.expand_dims(ops.arange(
ops.shape(nodes[1])), axis=0), ops.shape(node_len)[0], axis=0) < ops.expand_dims(node_len, axis=-1)
edge_mask = ops.repeat(ops.expand_dims(ops.arange(
ops.shape(nodes[1])), axis=0), ops.shape(node_len)[0], axis=0) < ops.expand_dims(node_len, axis=-1)
edge_indices_flatten = edge_indices[ops.cast(edge_mask, dtype="bool")]
nodes_flatten = nodes[ops.cast(node_mask, dtype="bool")]
edges_flatten = edges[ops.cast(edge_mask, dtype="bool")]

# Case 2: Ragged Tensor input.
if self.batch_info == "lengths":
nodes, edges, edge_indices, node_len, edge_len = inputs
node_len = ops.cast(node_len, dtype=self.batch_dtype)
edge_len = ops.cast(edge_len, dtype=self.batch_dtype)
node_mask = ops.repeat(ops.expand_dims(ops.arange(ops.shape(nodes)[1], dtype=self.batch_dtype), axis=0),
ops.shape(node_len)[0], axis=0) < ops.expand_dims(node_len, axis=-1)
edge_mask = ops.repeat(ops.expand_dims(ops.arange(ops.shape(edges)[1], dtype=self.batch_dtype), axis=0),
ops.shape(edge_len)[0], axis=0) < ops.expand_dims(edge_len, axis=-1)
edge_indices_flatten = edge_indices[edge_mask]
nodes_flatten = nodes[node_mask]
edges_flatten = edges[edge_mask]
elif self.batch_info == "mask":
nodes, edges, edge_indices, node_mask, edge_mask = inputs
edge_indices_flatten = edge_indices[ops.cast(edge_mask, dtype="bool")]
nodes_flatten = nodes[ops.cast(node_mask, dtype="bool")]
edges_flatten = edges[ops.cast(edge_mask, dtype="bool")]
node_len = ops.sum(ops.cast(node_mask, dtype=self.batch_dtype), axis=1)
edge_len = ops.sum(ops.cast(edge_mask, dtype=self.batch_dtype), axis=1)
else:
raise NotImplementedError("Unknown batch information '%s'." % b)

# Case 2: Fixed sized graphs without batch information.
elif all_tensor and len(inputs) == 3:
nodes, edges, edge_indices = inputs
n_shape, e_shape, ei_shape = ops.shape(nodes), ops.shape(edges), ops.shape(edge_indices)
nodes_flatten = ops.reshape(nodes, ops.concatenate([[n_shape[0] * n_shape[1]], n_shape[2:]]))
edges_flatten = ops.reshape(edges, ops.concatenate([[e_shape[0] * e_shape[1]], e_shape[2:]]))
edge_indices_flatten = ops.reshape(
edge_indices, ops.concatenate([[ei_shape[0] * ei_shape[1]], ei_shape[2:]]))
node_len = ops.repeat(ops.cast([n_shape[1]], dtype=self.batch_dtype), n_shape[0])
edge_len = ops.repeat(ops.cast([ei_shape[1]], dtype=self.batch_dtype), ei_shape[0])

# Case 3: Ragged tensor input.
# To be added as soon as ragged tensors are supported by Keras-Core.

# Unknown input raises an error.
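
Sanity-check sketch (illustrative only, not part of the commit): the mask construction of the "lengths" branch above, on toy numpy inputs; it shows only the masking and flattening, not the layer's full return values.

import numpy as np
from keras_core import ops

# Two graphs padded to 4 nodes each; graph 0 has 3 real nodes, graph 1 has 2.
nodes = np.arange(16, dtype="float32").reshape((2, 4, 2))
node_len = np.array([3, 2], dtype="int64")

# Same mask construction as in the "lengths" branch: position-in-graph < graph length.
node_mask = ops.repeat(
    ops.expand_dims(ops.arange(nodes.shape[1], dtype="int64"), axis=0),
    node_len.shape[0], axis=0) < ops.expand_dims(node_len, axis=-1)
# -> [[True, True, True, False],
#     [True, True, False, False]]

# Drop the padding rows to obtain the disjoint node list of shape (5, 2).
nodes_flatten = nodes[ops.convert_to_numpy(node_mask)]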
103 changes: 103 additions & 0 deletions kgcnn/layers_core/conv.py
@@ -0,0 +1,103 @@
from keras_core.layers import Layer, Dense, Activation
from keras_core import ops
from kgcnn.layers_core.aggr import AggregateWeightedLocalEdges
from kgcnn.layers_core.gather import GatherNodesOutgoing


class GCN(Layer):
r"""Graph convolution according to `Kipf et al <https://arxiv.org/abs/1609.02907>`__ .
Computes graph convolution as :math:`\sigma(A_s(WX+b))` where :math:`A_s` is the precomputed and scaled adjacency
matrix. The scaled adjacency matrix is defined by :math:`A_s = D^{-0.5} (A + I) D^{-0.5}` with the degree
matrix :math:`D` . In place of :math:`A_s` , this layers uses edge features (that are the entries of :math:`A_s` )
and edge indices.
.. note::
:math:`A_s` is considered pre-scaled, this is not done by this layer!
If no scaled edge features are available, you could consider use e.g. "mean",
or :obj:`normalize_by_weights` to obtain a similar behaviour that is expected b
y a pre-scaled adjacency matrix input.
Edge features must be possible to broadcast to node features, since they are multiplied with the node features.
Ideally they are weights of shape `(..., 1)` for broadcasting, e.g. entries of :math:`A_s` .
Args:
units (int): Output dimension/ units of dense layer.
pooling_method (str): Pooling method for summing edges. Default is 'sum'.
normalize_by_weights (bool): Normalize the pooled output by the sum of weights. Default is False.
In this case the edge features are considered weights of dimension (...,1) and are summed for each node.
activation (str): Activation. Default is 'kgcnn>leaky_relu'.
use_bias (bool): Use bias. Default is True.
kernel_regularizer: Kernel regularization. Default is None.
bias_regularizer: Bias regularization. Default is None.
activity_regularizer: Activity regularization. Default is None.
kernel_constraint: Kernel constraints. Default is None.
bias_constraint: Bias constraints. Default is None.
kernel_initializer: Initializer for kernels. Default is 'glorot_uniform'.
bias_initializer: Initializer for bias. Default is 'zeros'.
"""

def __init__(self,
units,
pooling_method='sum',
normalize_by_weights=False,
activation='kgcnn>leaky_relu',
use_bias=True,
kernel_regularizer=None,
bias_regularizer=None,
activity_regularizer=None,
kernel_constraint=None,
bias_constraint=None,
kernel_initializer='glorot_uniform',
bias_initializer='zeros',
**kwargs):
"""Initialize layer."""
super(GCN, self).__init__(**kwargs)
self.normalize_by_weights = normalize_by_weights
self.pooling_method = pooling_method
self.units = units
kernel_args = {"kernel_regularizer": kernel_regularizer, "activity_regularizer": activity_regularizer,
"bias_regularizer": bias_regularizer, "kernel_constraint": kernel_constraint,
"bias_constraint": bias_constraint, "kernel_initializer": kernel_initializer,
"bias_initializer": bias_initializer, "use_bias": use_bias}
pool_args = {"pooling_method": pooling_method, "normalize_by_weights": normalize_by_weights}

# Layers
self.lay_gather = GatherNodesOutgoing()
self.lay_dense = Dense(units=self.units, activation='linear', **kernel_args)
self.lay_pool = AggregateWeightedLocalEdges(**pool_args)
self.lay_act = Activation(activation)

def call(self, inputs, **kwargs):
"""Forward pass.
Args:
inputs: [nodes, edges, edge_index]
- nodes (tf.RaggedTensor): Node embeddings of shape (batch, [N], F)
- edges (tf.RaggedTensor): Edge or message embeddings of shape (batch, [M], F)
- edge_index (tf.RaggedTensor): Edge indices referring to nodes of shape (batch, [M], 2)
Returns:
tf.RaggedTensor: Node embeddings of shape (batch, [N], F)
"""
node, edges, edge_index = inputs
no = self.lay_dense(node, **kwargs)
no = self.lay_gather([no, edge_index], **kwargs)
nu = self.lay_pool([node, no, edge_index, edges], **kwargs) # Summing for each node connection
out = self.lay_act(nu, **kwargs)
return out

def get_config(self):
"""Update config."""
config = super(GCN, self).get_config()
config.update({"normalize_by_weights": self.normalize_by_weights,
"pooling_method": self.pooling_method, "units": self.units})
conf_dense = self.lay_dense.get_config()
for x in ["kernel_regularizer", "activity_regularizer", "bias_regularizer", "kernel_constraint",
"bias_constraint", "kernel_initializer", "bias_initializer", "use_bias"]:
config.update({x: conf_dense[x]})
conf_act = self.lay_act.get_config()
config.update({"activation": conf_act["activation"]})
return config
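
Worked example (illustrative only, not part of the commit): how pre-scaled edge weights, i.e. the entries of :math:`A_s`, could be derived with plain numpy before feeding them to this layer; the graph and variable names are made up.

import numpy as np

# Toy graph with 3 nodes: undirected edges (0-1) and (1-2).
A = np.array([[0., 1., 0.],
              [1., 0., 1.],
              [0., 1., 0.]], dtype="float32")
A_tilde = A + np.eye(3, dtype="float32")         # A + I (adds self-loops)
d_inv_sqrt = 1.0 / np.sqrt(A_tilde.sum(axis=1))  # diagonal of D^-0.5
A_s = d_inv_sqrt[:, None] * A_tilde * d_inv_sqrt[None, :]

# Edge list plus matching weights of shape (M, 1), the 'edges' input to GCN.
rows, cols = np.nonzero(A_tilde)
edge_index = np.stack([rows, cols], axis=0)      # shape (2, M)
edge_weights = A_s[rows, cols][:, None]          # entries of A_s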
27 changes: 27 additions & 0 deletions kgcnn/layers_core/gather.py
@@ -0,0 +1,27 @@
from keras_core.layers import Layer
from keras_core import ops


class GatherNodes(Layer):

    def __init__(self, split_indices=(0, 1), concat_axis=1, **kwargs):
        super(GatherNodes, self).__init__(**kwargs)
        self.split_indices = split_indices
        self.concat_axis = concat_axis

    def call(self, inputs, **kwargs):
        x, index = inputs
        # Gather node features for each selected index row, e.g. source and target nodes.
        gathered = [ops.take(x, index[i], axis=0) for i in self.split_indices]
        if self.concat_axis is not None:
            gathered = ops.concatenate(gathered, axis=self.concat_axis)
        return gathered


class GatherNodesOutgoing(GatherNodes):

    def __init__(self, **kwargs):
        super(GatherNodesOutgoing, self).__init__(split_indices=(1,), concat_axis=None, **kwargs)

    def call(self, inputs, **kwargs):
        # With concat_axis=None the parent returns a list; return its single entry.
        return super(GatherNodesOutgoing, self).call(inputs, **kwargs)[0]
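
Usage sketch (illustrative only, not part of the commit): the gather layers above on toy numpy inputs, assuming the fixes in this file and edge indices of shape (2, M) with source nodes in row 0 and target nodes in row 1.

import numpy as np
from kgcnn.layers_core.gather import GatherNodes, GatherNodesOutgoing

nodes = np.array([[1., 1.], [2., 2.], [3., 3.]], dtype="float32")
edge_index = np.array([[0, 1, 2],
                       [1, 2, 0]], dtype="int64")  # shape (2, 3)

pairs = GatherNodes()([nodes, edge_index])
# Concatenated source and target features per edge, shape (3, 4).
targets = GatherNodesOutgoing()([nodes, edge_index])
# Target node features per edge, shape (3, 2): rows of `nodes` at [1, 2, 0].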
