Skip to content

Commit

Permalink
continue keras core integration
Browse files Browse the repository at this point in the history
  • Loading branch information
PatReis committed Sep 8, 2023
1 parent aa7fb04 commit 4fb3044
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 20 deletions.
5 changes: 4 additions & 1 deletion kgcnn/layers_core/aggr.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,9 @@ def __init__(self, pooling_method: str = "scatter_sum", axis=0, **kwargs):
self._use_scatter = "scatter" in pooling_method

def build(self, input_shape):
# Nothing to build here. No sub-layers.
super(Aggregate, self).build(input_shape)

def compute_output_shape(self, input_shape):
assert len(input_shape) == 3
x_shape, _, dim_size = input_shape
Expand Down Expand Up @@ -82,6 +83,7 @@ def build(self, input_shape):
node_shape, edges_shape, edge_index_shape, weights_shape = input_shape
self.to_aggregate.build((edges_shape, edge_index_shape[1:], node_shape))
self.to_aggregate_weights.build((weights_shape, edge_index_shape[1:], node_shape))
self.built = True

def compute_output_shape(self, input_shape):
assert len(input_shape) == 4
Expand Down Expand Up @@ -135,6 +137,7 @@ def build(self, input_shape):
assert len(input_shape) == 4
node_shape, edges_shape, attention_shape, edge_index_shape = input_shape
self.to_aggregate.build((edges_shape, edge_index_shape[1:], node_shape))
self.built = True

def compute_output_shape(self, input_shape):
assert len(input_shape) == 4
Expand Down
6 changes: 3 additions & 3 deletions kgcnn/layers_core/casting.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@ def __init__(self, reverse_indices: bool = True, dtype_batch: str = "int64", dty
self.reverse_indices = reverse_indices
self.dtype_index = dtype_index
self.dtype_batch = dtype_batch

def build(self, input_shape):
return super(CastBatchedGraphIndicesToPyGDisjoint, self).build(input_shape)
super(CastBatchedGraphIndicesToPyGDisjoint, self).build(input_shape)

def call(self, inputs: list, **kwargs):
"""Changes node and edge indices into a Pytorch Geometric (PyG) compatible tensor format.
Expand Down Expand Up @@ -105,7 +105,7 @@ def __init__(self, reverse_indices: bool = True, dtype_batch: str = "int64", **k
self.dtype_batch = dtype_batch

def build(self, input_shape):
return super(CastBatchedGraphAttributesToPyGDisjoint, self).build(input_shape)
super(CastBatchedGraphAttributesToPyGDisjoint, self).build(input_shape)

def call(self, inputs: list, **kwargs):
"""Changes node or edge tensors into a Pytorch Geometric (PyG) compatible tensor format.
Expand Down
15 changes: 2 additions & 13 deletions kgcnn/layers_core/conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,18 +69,9 @@ def __init__(self,
self.layer_dense = Dense(units=self.units, activation='linear', **kernel_args)
self.layer_pool = AggregateWeightedLocalEdges(**pool_args)
self.layer_act = Activation(activation)

def build(self, input_shape):
assert isinstance(input_shape, list), "Require list input"
self.layer_dense.build(input_shape[0])
dense_shape = self.layer_dense.compute_output_shape(input_shape[0])
self.layer_gather.build([dense_shape, input_shape[2]])
gather_shape = self.layer_gather.compute_output_shape([dense_shape, input_shape[2]])
self.layer_pool.build([input_shape[0], gather_shape, input_shape[2], input_shape[1]])
pool_shape = self.layer_pool.compute_output_shape(
[input_shape[0], gather_shape, input_shape[2], input_shape[1]])
self.layer_act.build(pool_shape)
self.built = True
super(GCN, self).build(input_shape)

def call(self, inputs, **kwargs):
"""Forward pass.
Expand Down Expand Up @@ -165,7 +156,6 @@ def __init__(self, units,
self.lay_mult = Multiply()

def build(self, input_shape):
"""Build layer."""
super(SchNetCFconv, self).build(input_shape)

def call(self, inputs, **kwargs):
Expand Down Expand Up @@ -251,7 +241,6 @@ def __init__(self,
self.lay_add = Add()

def build(self, input_shape):
"""Build layer."""
super(SchNetInteraction, self).build(input_shape)

def call(self, inputs, **kwargs):
Expand Down
7 changes: 4 additions & 3 deletions kgcnn/layers_core/gather.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ def _compute_gathered_shape(self, input_shape):
return xs

def build(self, input_shape):
# We could call build on concatenate layer.
xs = self._compute_gathered_shape(input_shape)
if self.concat_axis is not None:
self._concat.build(xs)
Expand Down Expand Up @@ -49,7 +50,7 @@ def __init__(self, selection_index: int = 0, **kwargs):
self.selection_index = selection_index

def build(self, input_shape):
return super(GatherNodesOutgoing, self).build(input_shape)
super(GatherNodesOutgoing, self).build(input_shape)

def compute_output_shape(self, input_shape):
assert len(input_shape) == 2
Expand All @@ -68,7 +69,7 @@ def __init__(self, selection_index: int = 1, **kwargs):
self.selection_index = selection_index

def build(self, input_shape):
return super(GatherNodesIngoing, self).build(input_shape)
super(GatherNodesIngoing, self).build(input_shape)

def compute_output_shape(self, input_shape):
assert len(input_shape) == 2
Expand All @@ -87,7 +88,7 @@ def __init__(self, selection_index: int = 1, **kwargs):
self.selection_index = selection_index

def build(self, input_shape):
return super(GatherState, self).build(input_shape)
super(GatherState, self).build(input_shape)

def compute_output_shape(self, input_shape):
assert len(input_shape) == 2
Expand Down

0 comments on commit 4fb3044

Please sign in to comment.