Skip to content

Commit

Permalink
continue keras core integration
Browse files Browse the repository at this point in the history
  • Loading branch information
PatReis committed Sep 12, 2023
1 parent a6fee65 commit 50e856a
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 10 deletions.
15 changes: 7 additions & 8 deletions kgcnn/layers_core/casting.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,12 @@ def __init__(self, reverse_indices: bool = True, dtype_batch: str = "int64", dty
def build(self, input_shape):
self.built = True

def compute_output_shape(self, input_shape):
out_shape = [tuple([None] + list(input_shape[0][2:])), tuple(list(reversed(input_shape[1][2:])) + [None]),
(None, ), (None, ), (None, ), (None, )]
if len(input_shape) == 5:
out_shape = out_shape + [tuple([None] + list(input_shape[4][2:]))]
return out_shape
# def compute_output_shape(self, input_shape):
# out_shape = [tuple([None] + list(input_shape[0][2:])), tuple(list(reversed(input_shape[1][2:])) + [None]),
# (None, ), (None, ), (None, ), (None, )]
# if len(input_shape) == 5:
# out_shape = out_shape + [tuple([None] + list(input_shape[4][2:]))]
# return out_shape

def call(self, inputs: list, **kwargs):
"""Changes node and edge indices into a Pytorch Geometric (PyG) compatible tensor format.
Expand Down Expand Up @@ -217,10 +217,9 @@ def call(self, inputs: list, **kwargs):
node_mask_flatten = ops.reshape(node_mask, [-1])
nodes_flatten = pad_left(nodes_flatten)
node_len = cat_one(node_len)
node_mask_flatten = pad_left(node_mask_flatten)
nodes_id = repeat_static_length(ops.arange(ops.shape(node_len)[0], dtype=self.dtype_batch), node_len,
total_repeat_length=ops.shape(nodes_flatten)[0])

if self.padded_disjoint:
nodes_id = ops.where(node_mask_flatten, nodes_id, ops.convert_to_tensor(0, dtype=self.dtype_batch))

return [nodes_flatten, nodes_id, node_len]
Expand Down
1 change: 0 additions & 1 deletion kgcnn/literature_core/Schnet/_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,6 @@ def make_model(inputs: list = None,
if expand_distance:
ed = GaussBasisLayer(**gauss_args)(ed)

print(n, ed, disjoint_indices)
# Model
n = Dense(interaction_args["units"], activation='linear')(n)
for i in range(0, depth):
Expand Down
2 changes: 1 addition & 1 deletion training_core/hyper/hyper_esol.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
"input_node_embedding": {"input_dim": 95, "output_dim": 64},
"input_edge_embedding": {"input_dim": 25, "output_dim": 1},
"gcn_args": {"units": 140, "use_bias": True, "activation": "relu"},
"depth": 0, "verbose": 10,
"depth": 5, "verbose": 10,
"output_embedding": "graph",
"output_mlp": {"use_bias": [True, True, False], "units": [140, 70, 1],
"activation": ["relu", "relu", "linear"]},
Expand Down

0 comments on commit 50e856a

Please sign in to comment.