Skip to content

Commit

Permalink
continue keras core integration
Browse files Browse the repository at this point in the history
  • Loading branch information
PatReis committed Sep 13, 2023
1 parent 0e06932 commit fcc0ffa
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 9 deletions.
10 changes: 6 additions & 4 deletions kgcnn/literature_core/GCN/_model.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import keras_core as ks
from keras_core.layers import Dense
from kgcnn.layers_core.modules import Embedding
from kgcnn.layers_core.casting import CastBatchedIndicesToDisjoint, CastDisjointToGraph
from kgcnn.layers_core.casting import CastBatchedIndicesToDisjoint, CastDisjointToGraph, CastBatchedAttributesToDisjoint
from kgcnn.layers_core.conv import GCN
from kgcnn.layers_core.mlp import MLP
from kgcnn.layers_core.pooling import PoolingNodes
from kgcnn.model.utils import update_model_kwargs
from keras_core.backend import backend as backend_to_use

# from keras_core.layers import Activation
# from kgcnn.layers_core.aggr import AggregateWeightedLocalEdges
# from kgcnn.layers_core.gather import GatherNodesOutgoing
Expand Down Expand Up @@ -100,8 +101,9 @@ def make_model(inputs: list = None,
# Make input
model_inputs = [ks.layers.Input(**x) for x in inputs]
batched_nodes, batched_edges, batched_indices, total_nodes, total_edges = model_inputs
n, disjoint_indices, node_id, edge_id, count_nodes, count_edges, e = CastBatchedIndicesToDisjoint(
**cast_disjoint_kwargs)([batched_nodes, batched_indices, total_nodes, total_edges, batched_edges])
n, disjoint_indices, batch_id_node, batch_id_edge, node_id, edge_id, count_nodes, count_edges = CastBatchedIndicesToDisjoint(
**cast_disjoint_kwargs)([batched_nodes, batched_indices, total_nodes, total_edges])
e, _, _, _ = CastBatchedAttributesToDisjoint(**cast_disjoint_kwargs)([batched_edges, total_edges])

# Embedding, if no feature dimension
if len(inputs[0]['shape']) < 2:
Expand All @@ -123,7 +125,7 @@ def make_model(inputs: list = None,

# Output embedding choice
if output_embedding == "graph":
out = PoolingNodes()([count_nodes, n, node_id]) # will return tensor
out = PoolingNodes()([count_nodes, n, batch_id_node]) # will return tensor
out = MLP(**output_mlp)(out)
out = CastDisjointToGraph(**cast_disjoint_kwargs)(out)
elif output_embedding == "node":
Expand Down
9 changes: 4 additions & 5 deletions kgcnn/literature_core/Schnet/_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,19 +121,19 @@ def make_model(inputs: list = None,
# Make input
model_inputs = [ks.layers.Input(**x) for x in inputs]
batched_nodes, batched_x, batched_indices, total_nodes, total_edges = model_inputs
n, disjoint_indices, node_id, edge_id, count_nodes, count_edges = CastBatchedIndicesToDisjoint(
n, disjoint_indices, batch_id_node, batch_id_edge, node_id, edge_id, count_nodes, count_edges = CastBatchedIndicesToDisjoint(
**cast_disjoint_kwargs)([batched_nodes, batched_indices, total_nodes, total_edges])

# Optional Embedding.
if len(inputs[0]['shape']) < 2:
n = Embedding(**input_node_embedding)(n)

if make_distance:
x, _, _ = CastBatchedAttributesToDisjoint(**cast_disjoint_kwargs)([batched_x, total_nodes])
x, _, _, _ = CastBatchedAttributesToDisjoint(**cast_disjoint_kwargs)([batched_x, total_nodes])
pos1, pos2 = NodePosition()([x, disjoint_indices])
ed = NodeDistanceEuclidean()([pos1, pos2])
else:
ed, _, _ = CastBatchedAttributesToDisjoint(**cast_disjoint_kwargs)([batched_x, total_edges])
ed, _, _, _ = CastBatchedAttributesToDisjoint(**cast_disjoint_kwargs)([batched_x, total_edges])

if expand_distance:
ed = GaussBasisLayer(**gauss_args)(ed)
Expand All @@ -147,7 +147,7 @@ def make_model(inputs: list = None,

# Output embedding choice
if output_embedding == 'graph':
out = PoolingNodes(**node_pooling_args)([count_nodes, n, node_id])
out = PoolingNodes(**node_pooling_args)([count_nodes, n, batch_id_node])
if use_output_mlp:
out = MLP(**output_mlp)(out)
out = CastDisjointToGraph(**cast_disjoint_kwargs)(out)
Expand All @@ -159,7 +159,6 @@ def make_model(inputs: list = None,
raise ValueError("Unsupported output embedding for mode `SchNet` .")

model = ks.models.Model(inputs=model_inputs, outputs=out)

model.__kgcnn_model_version__ = __model_version__
return model

Expand Down

0 comments on commit fcc0ffa

Please sign in to comment.