Skip to content

Commit

Permalink
continue keras core integration
Browse files Browse the repository at this point in the history
  • Loading branch information
PatReis committed Sep 13, 2023
1 parent a1eb7d7 commit 84540c7
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 12 deletions.
7 changes: 3 additions & 4 deletions kgcnn/layers_core/casting.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,9 +174,8 @@ def call(self, inputs: list, **kwargs):
disjoint_indices = edge_indices_flatten + ops.cast(offset_edge_indices, edge_indices_flatten.dtype)
edge_mask_flatten = ops.expand_dims(edge_mask_flatten, axis=-1)
disjoint_indices = ops.where(edge_mask_flatten, disjoint_indices, 0)

node_len = ops.concatenate([ops.sum(node_len_flat[1:] - node_len, keepdims=True), node_len], axis=0)
edge_len = ops.concatenate([ops.sum(edge_len_flat[1:] - edge_len, keepdims=True), edge_len], axis=0)
node_len = ops.concatenate([ops.sum(node_len_flat[1:] - node_len, axis=0, keepdims=True), node_len], axis=0)
edge_len = ops.concatenate([ops.sum(edge_len_flat[1:] - edge_len, axis=0, keepdims=True), edge_len], axis=0)

# Transpose edge indices.
disjoint_indices = ops.transpose(disjoint_indices)
Expand Down Expand Up @@ -285,7 +284,7 @@ def call(self, inputs: list, **kwargs):
total_repeat_length=ops.shape(nodes_flatten)[0])
graph_id_node = ops.where(node_mask_flatten, graph_id, 0)
node_id = ops.where(node_mask_flatten, node_id, 0)
node_len = ops.concatenate([ops.sum(node_len_flat[1:] - node_len, keepdims=True), node_len], axis=0)
node_len = ops.concatenate([ops.sum(node_len_flat[1:] - node_len, axis=0, keepdims=True), node_len], axis=0)

return [nodes_flatten, graph_id_node, node_id, node_len]

Expand Down
6 changes: 4 additions & 2 deletions kgcnn/literature_core/GCN/_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
"gcn_args": {"units": 100, "use_bias": True, "activation": "relu", "pooling_method": "sum"},
"depth": 3,
"verbose": 10,
"node_pooling_args": {"pooling_method": "scatter_sum"},
"output_embedding": "graph",
"output_to_tensor": True,
"output_mlp": {"use_bias": [True, True, False], "units": [25, 10, 1],
Expand All @@ -57,6 +58,7 @@ def make_model(inputs: list = None,
gcn_args: dict = None,
name: str = None,
verbose: int = None,
node_pooling_args: dict = None,
output_embedding: str = None,
output_to_tensor: bool = None,
output_mlp: dict = None):
Expand Down Expand Up @@ -86,6 +88,7 @@ def make_model(inputs: list = None,
gcn_args (dict): Dictionary of layer arguments unpacked in :obj:`GCN` convolutional layer.
name (str): Name of the model.
verbose (int): Level of print output.
node_pooling_args (dict): Dictionary of layer arguments unpacked in :obj:`PoolingNodes` layer.
output_embedding (str): Main embedding task for graph network. Either "node", "edge" or "graph".
output_to_tensor (bool): Whether to cast model output to :obj:`Tensor`.
output_mlp (dict): Dictionary of layer arguments unpacked in the final classification :obj:`MLP` layer block.
Expand Down Expand Up @@ -113,7 +116,6 @@ def make_model(inputs: list = None,

# Model
n = Dense(gcn_args["units"], use_bias=True, activation='linear')(n) # Map to units

for i in range(0, depth):
n = GCN(**gcn_args)([n, e, disjoint_indices])

Expand All @@ -125,7 +127,7 @@ def make_model(inputs: list = None,

# Output embedding choice
if output_embedding == "graph":
out = PoolingNodes()([count_nodes, n, batch_id_node]) # will return tensor
out = PoolingNodes(**node_pooling_args)([count_nodes, n, batch_id_node]) # will return tensor
out = MLP(**output_mlp)(out)
out = CastDisjointToGraph(**cast_disjoint_kwargs)(out)
elif output_embedding == "node":
Expand Down
12 changes: 6 additions & 6 deletions training_core/results/ESOLDataset/GCN/GCN_ESOLDataset_score.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ cuda_device_id: '[0]'
cuda_device_memory: '[None]'
cuda_device_name: '[CpuDevice(id=0)]'
data_unit: mol/L
date_time: '2023-09-12 21:03:25'
date_time: '2023-09-13 15:41:58'
epochs:
- 800
- 800
Expand Down Expand Up @@ -129,11 +129,11 @@ scaled_root_mean_squared_error:
- 0.06261409819126129
seed: 42
time_list:
- '0:09:12.199272'
- '0:09:02.642175'
- '0:08:50.004458'
- '0:08:49.564425'
- '0:08:55.888719'
- '0:09:17.532198'
- '0:09:22.549678'
- '0:09:27.358056'
- '0:09:40.644600'
- '0:09:28.633570'
val_loss:
- 0.2157980352640152
- 0.18899810314178467
Expand Down

0 comments on commit 84540c7

Please sign in to comment.