Skip to content

Commit

Permalink
continue keras core integration
Browse files Browse the repository at this point in the history
  • Loading branch information
PatReis committed Sep 12, 2023
1 parent 6261c4d commit dc3f754
Show file tree
Hide file tree
Showing 5 changed files with 13 additions and 14 deletions.
14 changes: 6 additions & 8 deletions kgcnn/layers_core/casting.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,7 @@ def cat_one(t):
class CastBatchedGraphIndicesToDisjoint(Layer):
"""Cast batched node and edge tensors to a (single) disjoint graph representation of Pytorch Geometric (PyG).
For PyG a batch of graphs is represented by single graph which contains disjoint sub-graphs,
and the batch information is passed as batch ID tensor: `nodes_id` or `edge_id` .
and the batch information is passed as batch ID tensor: `nodes_id` and `edge_id` .
Keras layers can pass unstacked tensors without batch dimension, however, for model input and output
batched tensors are currently built into the framework.
Expand Down Expand Up @@ -238,17 +237,16 @@ def build(self, input_shape):
self.built = True

def compute_output_shape(self, input_shape):
return [tuple([None] + list(input_shape[0][2:])), (None,)]
if self.padded_disjoint:
if input_shape[0] is not None:
return tuple([input_shape[0]-1] + list(input_shape[1:]))
return input_shape

def call(self, inputs: list, **kwargs):
"""Changes node or edge tensors into a Pytorch Geometric (PyG) compatible tensor format.
Args:
inputs (list): List of `[attr, counts_in_batch]` ,
- attr (Tensor): Features are represented by a keras tensor of shape `(batch, N, F, ...)` ,
where N denotes the number of nodes or edges.
- counts_in_batch (Tensor):
inputs (Tensor): Graph labels from a disjoint representation of shape `(batch, ...)` .
Returns:
Tensor: Graph labels of shape `(batch, ...)` .
Expand Down
8 changes: 4 additions & 4 deletions kgcnn/layers_core/pooling.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,12 @@ def __init__(self, pooling_method="scatter_sum", **kwargs):
self._to_aggregate = Aggregate(pooling_method=pooling_method)

def build(self, input_shape):
self._to_aggregate.build([input_shape[0], (None, ), input_shape[1]])
self._to_aggregate.build([input_shape[1], input_shape[2], input_shape[0]])
self.built = True

def compute_output_shape(self, input_shape):
return tuple(list(input_shape[1][:1]) + list(input_shape[0][1:]))
return self._to_aggregate.compute_output_shape([input_shape[1], input_shape[2], input_shape[0]])

def call(self, inputs, **kwargs):
lengths, x, batch = inputs
return self._to_aggregate([x, batch, lengths])
reference, x, idx = inputs
return self._to_aggregate([x, idx, reference])
2 changes: 1 addition & 1 deletion kgcnn/literature_core/GCN/_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def make_model(inputs: list = None,

# Output embedding choice
if output_embedding == "graph":
out = PoolingNodes()([count_nodes, n, disjoint_indices]) # will return tensor
out = PoolingNodes()([count_nodes, n, node_id]) # will return tensor
out = MLP(**output_mlp)(out)
out = CastDisjointToGraphLabels(**cast_disjoint_kwargs)(out)
elif output_embedding == "node":
Expand Down
1 change: 1 addition & 0 deletions training_core/hyper/hyper_esol.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
{"shape": (), "name": "total_nodes", "dtype": "int64"},
{"shape": (), "name": "total_edges", "dtype": "int64"}
],
"cast_disjoint_kwargs": {"padded_disjoint": True},
"input_node_embedding": {"input_dim": 95, "output_dim": 64},
"input_edge_embedding": {"input_dim": 25, "output_dim": 1},
"gcn_args": {"units": 140, "use_bias": True, "activation": "relu"},
Expand Down
2 changes: 1 addition & 1 deletion training_core/train_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@
# The metrics from this script is added to the hyperparameter entry for metrics.
model.compile(**hyper.compile(metrics=metrics))
model.summary()
print("Compiled with jit: %s." % model._jit_compile)
print(" Compiled with jit: %s" % model._jit_compile)
# Run keras model-fit and take time for training.
start = time.time()
hist = model.fit(x_train, y_train,
Expand Down

0 comments on commit dc3f754

Please sign in to comment.