From dc3f754534fe39d75b7daa6a76969919eb9eafc7 Mon Sep 17 00:00:00 2001 From: PatReis Date: Tue, 12 Sep 2023 16:55:17 +0200 Subject: [PATCH] continue keras core integration --- kgcnn/layers_core/casting.py | 14 ++++++-------- kgcnn/layers_core/pooling.py | 8 ++++---- kgcnn/literature_core/GCN/_model.py | 2 +- training_core/hyper/hyper_esol.py | 1 + training_core/train_graph.py | 2 +- 5 files changed, 13 insertions(+), 14 deletions(-) diff --git a/kgcnn/layers_core/casting.py b/kgcnn/layers_core/casting.py index d83a87a6..bb0e2ee7 100644 --- a/kgcnn/layers_core/casting.py +++ b/kgcnn/layers_core/casting.py @@ -14,8 +14,7 @@ def cat_one(t): class CastBatchedGraphIndicesToDisjoint(Layer): """Cast batched node and edge tensors to a (single) disjoint graph representation of Pytorch Geometric (PyG). For PyG a batch of graphs is represented by single graph which contains disjoint sub-graphs, - and the batch information is passed as batch ID tensor: `nodes_id` or `edge_id` . - + and the batch information is passed as batch ID tensor: `nodes_id` and `edge_id` . Keras layers can pass unstacked tensors without batch dimension, however, for model input and output batched tensors is currently built in the framework. @@ -238,17 +237,16 @@ def build(self, input_shape): self.built = True def compute_output_shape(self, input_shape): - return [tuple([None] + list(input_shape[0][2:])), (None,)] + if self.padded_disjoint: + if input_shape[0] is not None: + return tuple([input_shape[0]-1] + list(input_shape[1:])) + return input_shape def call(self, inputs: list, **kwargs): """Changes node or edge tensors into a Pytorch Geometric (PyG) compatible tensor format. Args: - inputs (list): List of `[attr, counts_in_batch]` , - - - attr (Tensor): Features are represented by a keras tensor of shape `(batch, N, F, ...)` , - where N denotes the number of nodes or edges. - - counts_in_batch (Tensor): + inputs (Tensor): Graph labels from a disjoint representation of shape `(batch, ...)` . Returns: Tensor: Graph labels of shape `(batch, ...)` . diff --git a/kgcnn/layers_core/pooling.py b/kgcnn/layers_core/pooling.py index 555561bc..a8f962e0 100644 --- a/kgcnn/layers_core/pooling.py +++ b/kgcnn/layers_core/pooling.py @@ -12,12 +12,12 @@ def __init__(self, pooling_method="scatter_sum", **kwargs): self._to_aggregate = Aggregate(pooling_method=pooling_method) def build(self, input_shape): - self._to_aggregate.build([input_shape[0], (None, ), input_shape[1]]) + self._to_aggregate.build([input_shape[1], input_shape[2], input_shape[0]]) self.built = True def compute_output_shape(self, input_shape): - return tuple(list(input_shape[1][:1]) + list(input_shape[0][1:])) + return self._to_aggregate.compute_output_shape([input_shape[1], input_shape[2], input_shape[0]]) def call(self, inputs, **kwargs): - lengths, x, batch = inputs - return self._to_aggregate([x, batch, lengths]) + reference, x, idx = inputs + return self._to_aggregate([x, idx, reference]) diff --git a/kgcnn/literature_core/GCN/_model.py b/kgcnn/literature_core/GCN/_model.py index 1346bd8e..0e7f03d9 100644 --- a/kgcnn/literature_core/GCN/_model.py +++ b/kgcnn/literature_core/GCN/_model.py @@ -123,7 +123,7 @@ def make_model(inputs: list = None, # Output embedding choice if output_embedding == "graph": - out = PoolingNodes()([count_nodes, n, disjoint_indices]) # will return tensor + out = PoolingNodes()([count_nodes, n, node_id]) # will return tensor out = MLP(**output_mlp)(out) out = CastDisjointToGraphLabels(**cast_disjoint_kwargs)(out) elif output_embedding == "node": diff --git a/training_core/hyper/hyper_esol.py b/training_core/hyper/hyper_esol.py index 5b1e0400..112dd1ff 100644 --- a/training_core/hyper/hyper_esol.py +++ b/training_core/hyper/hyper_esol.py @@ -12,6 +12,7 @@ {"shape": (), "name": "total_nodes", "dtype": "int64"}, {"shape": (), "name": "total_edges", "dtype": "int64"} ], + "cast_disjoint_kwargs": {"padded_disjoint": True}, "input_node_embedding": {"input_dim": 95, "output_dim": 64}, "input_edge_embedding": {"input_dim": 25, "output_dim": 1}, "gcn_args": {"units": 140, "use_bias": True, "activation": "relu"}, diff --git a/training_core/train_graph.py b/training_core/train_graph.py index 1b0cb017..169fb6c9 100644 --- a/training_core/train_graph.py +++ b/training_core/train_graph.py @@ -112,7 +112,7 @@ # The metrics from this script is added to the hyperparameter entry for metrics. model.compile(**hyper.compile(metrics=metrics)) model.summary() - print("Compiled with jit: %s." % model._jit_compile) + print(" Compiled with jit: %s" % model._jit_compile) # Run keras model-fit and take time for training. start = time.time() hist = model.fit(x_train, y_train,