Skip to content

Commit

Permalink
continue keras core integration
Browse files Browse the repository at this point in the history
  • Loading branch information
PatReis committed Sep 12, 2023
1 parent dc3f754 commit 7e57f4c
Show file tree
Hide file tree
Showing 8 changed files with 144 additions and 105 deletions.
4 changes: 4 additions & 0 deletions kgcnn/backend/jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,7 @@ def scatter_reduce_mean(indices, values, shape):
counts = jnp.zeros(shape, values.dtype)
counts.at[indices].add(jnp.ones_like(values))
return zeros.at[indices].add(values)/counts


def repeat_static_length(x, repeats, axis=None, total_repeat_length: int = None):
return jnp.repeat(x, repeats=repeats, axis=axis, total_repeat_length=total_repeat_length)
4 changes: 4 additions & 0 deletions kgcnn/backend/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,7 @@ def scatter_reduce_softmax(indices, values, shape):
values_exp_sum = tf.scatter_nd(indices, values_exp, tf.cast(shape, dtype="int64"))
values_exp_sum = tf.gather(values_exp_sum, indices, axis=0)
return values_exp / values_exp_sum


def repeat_static_length(x, repeats, axis=None, total_repeat_length: int = None):
return tf.repeat(x, repeats=repeats, axis=axis)
5 changes: 5 additions & 0 deletions kgcnn/backend/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,8 @@ def scatter_reduce_softmax(indices, values, shape):
indices = torch.unsqueeze(indices, dim=-1)
return torch.zeros(*shape, dtype=values.dtype, device=values.device).scatter_reduce(
0, torch.broadcast_to(indices, values.shape), values, reduce='sum')


def repeat_static_length(x, repeats, axis=None, total_repeat_length: int = None):
# from keras_core.backend.torch.numpy import repeat
return torch.repeat_interleave(x, repeats, dim=axis)
11 changes: 8 additions & 3 deletions kgcnn/layers_core/casting.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from keras_core.layers import Layer
from keras_core import ops
from kgcnn.ops_core.core import repeat_static_length
# from keras_core.backend import backend


Expand Down Expand Up @@ -118,15 +119,19 @@ def call(self, inputs: list, **kwargs):
if self.dtype_index is not None:
edge_indices_flatten = ops.cast(edge_indices_flatten, dtype=self.dtype_index)

nodes_id = ops.repeat(ops.arange(ops.shape(node_len)[0], dtype=self.dtype_batch), node_len)
edges_id = ops.repeat(ops.arange(ops.shape(edge_len)[0], dtype=self.dtype_batch), edge_len)
nodes_id = repeat_static_length(ops.arange(ops.shape(node_len)[0], dtype=self.dtype_batch), node_len,
total_repeat_length=ops.shape(node_mask_flatten)[0])
edges_id = repeat_static_length(ops.arange(ops.shape(edge_len)[0], dtype=self.dtype_batch), edge_len,
total_repeat_length=ops.shape(edge_mask_flatten)[0])

if self.padded_disjoint:
nodes_id = ops.where(node_mask_flatten, nodes_id, ops.convert_to_tensor(0, dtype=self.dtype_batch))
edges_id = ops.where(edge_mask_flatten, edges_id, ops.convert_to_tensor(0, dtype=self.dtype_batch))

node_splits = ops.pad(ops.cumsum(node_len), [[1, 0]])
offset_edge_indices = ops.expand_dims(ops.repeat(node_splits[:-1], edge_len), axis=-1)
offset_edge_indices = ops.expand_dims(
repeat_static_length(node_splits[:-1], edge_len, total_repeat_length=ops.shape(edge_indices_flatten)[0])
, axis=-1)
offset_edge_indices = ops.broadcast_to(offset_edge_indices, ops.shape(edge_indices_flatten))

disjoint_indices = edge_indices_flatten + ops.cast(offset_edge_indices, edge_indices_flatten.dtype)
Expand Down
21 changes: 21 additions & 0 deletions kgcnn/ops_core/core.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from keras_core.ops import any_symbolic_tensors
from keras_core.ops.numpy import Repeat
import kgcnn.backend as kgcnn_backend


def repeat_static_length(x, repeats, axis=None, total_repeat_length: int = None):
"""Repeat each element of a tensor after themselves.
Args:
x: Input tensor.
repeats: The number of repetitions for each element.
axis: The axis along which to repeat values. By default, use
the flattened input array, and return a flat output array.
total_repeat_length: length of all repeats.
Returns:
Output tensor.
"""
if any_symbolic_tensors((x,)):
return Repeat(repeats, axis=axis).symbolic_call(x)
return kgcnn_backend.repeat_static_length(x, repeats, axis=axis, total_repeat_length=total_repeat_length)
2 changes: 1 addition & 1 deletion training_core/hyper/hyper_esol.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
"input_node_embedding": {"input_dim": 95, "output_dim": 64},
"input_edge_embedding": {"input_dim": 25, "output_dim": 1},
"gcn_args": {"units": 140, "use_bias": True, "activation": "relu"},
"depth": 5, "verbose": 10,
"depth": 0, "verbose": 10,
"output_embedding": "graph",
"output_mlp": {"use_bias": [True, True, False], "units": [140, 70, 1],
"activation": ["relu", "relu", "linear"]},
Expand Down
200 changes: 100 additions & 100 deletions training_core/results/ESOLDataset/GCN/GCN_ESOLDataset_score.yaml
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
OS: nt_win32
backend: torch
cuda_available: 'True'
backend: jax
cuda_available: 'False'
cuda_device_id: '[0]'
cuda_device_memory: '[{''allocated'': 0.0, ''cached'': 0.0}]'
cuda_device_name: '[''NVIDIA GeForce GTX 1060 6GB'']'
cuda_device_memory: '[None]'
cuda_device_name: '[CpuDevice(id=0)]'
data_unit: mol/L
date_time: '2023-09-12 15:23:18'
date_time: '2023-09-12 18:43:51'
epochs:
- 800
- 800
Expand All @@ -21,134 +21,134 @@ learning_rate:
- 5.1727271056734025e-05
- 5.1727271056734025e-05
loss:
- 0.0066009266301989555
- 0.008980654180049896
- 0.004737056326121092
- 0.008513851091265678
- 0.005296375136822462
- 0.03053712472319603
- 0.03404470160603523
- 0.03311067447066307
- 0.033033594489097595
- 0.03143709525465965
max_learning_rate:
- 0.0010000000474974513
- 0.0010000000474974513
- 0.0010000000474974513
- 0.0010000000474974513
- 0.0010000000474974513
max_loss:
- 0.7155899405479431
- 0.774999737739563
- 0.7535335421562195
- 0.7142425775527954
- 0.7380039095878601
- 0.9184737801551819
- 0.8267390131950378
- 0.7701554894447327
- 0.8929275870323181
- 0.9780256748199463
max_scaled_mean_absolute_error:
- 1.504331111907959
- 1.6446155309677124
- 1.5878634452819824
- 1.4825392961502075
- 1.5574897527694702
- 1.935563087463379
- 1.766575574874878
- 1.6337051391601562
- 1.8669281005859375
- 2.0824332237243652
max_scaled_root_mean_squared_error:
- 1.8573236465454102
- 2.033061981201172
- 1.9877549409866333
- 1.8397737741470337
- 1.9624701738357544
- 2.5540788173675537
- 2.321408748626709
- 2.1065993309020996
- 2.452350616455078
- 2.6977622509002686
max_val_loss:
- 0.3724708557128906
- 0.30009183287620544
- 0.2972280979156494
- 0.30284008383750916
- 0.3771456182003021
- 0.3772163987159729
- 0.3763466477394104
- 0.33287692070007324
- 0.4974099397659302
- 0.4094856083393097
max_val_scaled_mean_absolute_error:
- 0.7121237516403198
- 0.6909666061401367
- 0.5778587460517883
- 0.6605700850486755
- 0.7213104963302612
- 0.8177486062049866
- 0.8109288811683655
- 0.6679871678352356
- 0.988940954208374
- 0.801365852355957
max_val_scaled_root_mean_squared_error:
- 0.987949788570404
- 0.8955281376838684
- 0.8078282475471497
- 0.8567005395889282
- 0.9503365755081177
- 1.0927287340164185
- 1.0821540355682373
- 0.8984454870223999
- 1.239497184753418
- 1.056252121925354
min_learning_rate:
- 5.1727271056734025e-05
- 5.1727271056734025e-05
- 5.1727271056734025e-05
- 5.1727271056734025e-05
- 5.1727271056734025e-05
min_loss:
- 0.006196807138621807
- 0.007041560485959053
- 0.004737056326121092
- 0.008513851091265678
- 0.005086708813905716
- 0.03053712472319603
- 0.03404470160603523
- 0.03311067447066307
- 0.03280334919691086
- 0.030585210770368576
min_scaled_mean_absolute_error:
- 0.01290896162390709
- 0.01511739008128643
- 0.010167216882109642
- 0.018042149022221565
- 0.010847717523574829
- 0.0633389949798584
- 0.07146354764699936
- 0.07079430669546127
- 0.06939832121133804
- 0.0653739720582962
min_scaled_root_mean_squared_error:
- 0.07693216949701309
- 0.062237679958343506
- 0.0496278777718544
- 0.07619256526231766
- 0.05293257534503937
- 0.16624480485916138
- 0.1832047998905182
- 0.18838366866111755
- 0.1890300065279007
- 0.17916052043437958
min_val_loss:
- 0.2154006063938141
- 0.18383672833442688
- 0.2109178751707077
- 0.19489336013793945
- 0.19204500317573547
- 0.22608399391174316
- 0.22066998481750488
- 0.21403798460960388
- 0.2186526656150818
- 0.20758692920207977
min_val_scaled_mean_absolute_error:
- 0.46074214577674866
- 0.42304232716560364
- 0.42409372329711914
- 0.4583303928375244
- 0.42986610531806946
- 0.4922861158847809
- 0.4800896644592285
- 0.48117125034332275
- 0.5147459506988525
- 0.45930004119873047
min_val_scaled_root_mean_squared_error:
- 0.6589117050170898
- 0.5896828174591064
- 0.6173902153968811
- 0.641599714756012
- 0.5974742770195007
- 0.718582272529602
- 0.6697787046432495
- 0.7009758353233337
- 0.7098767757415771
- 0.6274582743644714
model_class: make_model
model_name: GCN
model_version: 2023.09.30
multi_target_indices: null
number_histories: 5
scaled_mean_absolute_error:
- 0.01290896162390709
- 0.016202447935938835
- 0.010255923494696617
- 0.018042149022221565
- 0.010847717523574829
- 0.0633389949798584
- 0.07158304750919342
- 0.07079430669546127
- 0.06939832121133804
- 0.0659930482506752
scaled_root_mean_squared_error:
- 0.0771808922290802
- 0.06273077428340912
- 0.04992813244462013
- 0.07626795023679733
- 0.05293257534503937
- 0.16624480485916138
- 0.183350071310997
- 0.18849699199199677
- 0.189775288105011
- 0.18083412945270538
seed: 42
time_list:
- '0:09:39.162853'
- '0:08:58.291624'
- '0:09:58.646090'
- '0:08:58.662076'
- '0:09:05.079421'
- '0:01:33.268670'
- '0:01:30.225590'
- '0:01:31.639611'
- '0:01:37.904172'
- '0:01:42.706901'
val_loss:
- 0.2381785809993744
- 0.1950930505990982
- 0.22705580294132233
- 0.19656625390052795
- 0.22198894619941711
- 0.23893700540065765
- 0.23162303864955902
- 0.2550176978111267
- 0.22663560509681702
- 0.23079481720924377
val_scaled_mean_absolute_error:
- 0.5136747360229492
- 0.44999390840530396
- 0.4479023516178131
- 0.46139630675315857
- 0.4667053520679474
- 0.5077677369117737
- 0.5055126547813416
- 0.503480076789856
- 0.5220910310745239
- 0.48052480816841125
val_scaled_root_mean_squared_error:
- 0.7505388855934143
- 0.6257898807525635
- 0.6442564725875854
- 0.6493059992790222
- 0.657427966594696
- 0.7481772303581238
- 0.7289010286331177
- 0.7459499835968018
- 0.7468581795692444
- 0.6586126089096069
2 changes: 1 addition & 1 deletion training_core/results/ESOLDataset/GCN/GCN_hyper.json
Original file line number Diff line number Diff line change
@@ -1 +1 @@
{"model": {"class_name": "make_model", "module_name": "kgcnn.literature_core.GCN", "config": {"name": "GCN", "inputs": [{"shape": [null, 41], "name": "node_attributes", "dtype": "float32"}, {"shape": [null, 1], "name": "edge_weights", "dtype": "float32"}, {"shape": [null, 2], "name": "edge_indices", "dtype": "int64"}, {"shape": [], "name": "total_nodes", "dtype": "int64"}, {"shape": [], "name": "total_edges", "dtype": "int64"}], "input_node_embedding": {"input_dim": 95, "output_dim": 64}, "input_edge_embedding": {"input_dim": 25, "output_dim": 1}, "gcn_args": {"units": 140, "use_bias": true, "activation": "relu"}, "depth": 5, "verbose": 10, "output_embedding": "graph", "output_mlp": {"use_bias": [true, true, false], "units": [140, 70, 1], "activation": ["relu", "relu", "linear"]}}}, "training": {"fit": {"batch_size": 32, "epochs": 800, "validation_freq": 10, "verbose": 2, "callbacks": [{"class_name": "kgcnn>LinearLearningRateScheduler", "config": {"learning_rate_start": 0.001, "learning_rate_stop": 5e-05, "epo_min": 250, "epo": 800, "verbose": 0}}]}, "compile": {"optimizer": {"class_name": "Adam", "config": {"learning_rate": 0.001}}, "loss": "mean_absolute_error"}, "cross_validation": {"class_name": "KFold", "config": {"n_splits": 5, "random_state": 42, "shuffle": true}}, "scaler": {"class_name": "StandardScaler", "config": {"with_std": true, "with_mean": true, "copy": true}}}, "dataset": {"class_name": "ESOLDataset", "module_name": "kgcnn.data.datasets.ESOLDataset", "config": {}, "methods": [{"set_attributes": {}}, {"map_list": {"method": "normalize_edge_weights_sym"}}, {"map_list": {"method": "count_nodes_and_edges"}}]}, "data": {"data_unit": "mol/L"}, "info": {"postfix": "", "postfix_file": "", "kgcnn_version": "4.0.0"}}
{"model": {"class_name": "make_model", "module_name": "kgcnn.literature_core.GCN", "config": {"name": "GCN", "inputs": [{"shape": [null, 41], "name": "node_attributes", "dtype": "float32"}, {"shape": [null, 1], "name": "edge_weights", "dtype": "float32"}, {"shape": [null, 2], "name": "edge_indices", "dtype": "int64"}, {"shape": [], "name": "total_nodes", "dtype": "int64"}, {"shape": [], "name": "total_edges", "dtype": "int64"}], "cast_disjoint_kwargs": {"padded_disjoint": true}, "input_node_embedding": {"input_dim": 95, "output_dim": 64}, "input_edge_embedding": {"input_dim": 25, "output_dim": 1}, "gcn_args": {"units": 140, "use_bias": true, "activation": "relu"}, "depth": 0, "verbose": 10, "output_embedding": "graph", "output_mlp": {"use_bias": [true, true, false], "units": [140, 70, 1], "activation": ["relu", "relu", "linear"]}}}, "training": {"fit": {"batch_size": 32, "epochs": 800, "validation_freq": 10, "verbose": 2, "callbacks": [{"class_name": "kgcnn>LinearLearningRateScheduler", "config": {"learning_rate_start": 0.001, "learning_rate_stop": 5e-05, "epo_min": 250, "epo": 800, "verbose": 0}}]}, "compile": {"optimizer": {"class_name": "Adam", "config": {"learning_rate": 0.001}}, "loss": "mean_absolute_error"}, "cross_validation": {"class_name": "KFold", "config": {"n_splits": 5, "random_state": 42, "shuffle": true}}, "scaler": {"class_name": "StandardScaler", "config": {"with_std": true, "with_mean": true, "copy": true}}}, "dataset": {"class_name": "ESOLDataset", "module_name": "kgcnn.data.datasets.ESOLDataset", "config": {}, "methods": [{"set_attributes": {}}, {"map_list": {"method": "normalize_edge_weights_sym"}}, {"map_list": {"method": "count_nodes_and_edges"}}]}, "data": {"data_unit": "mol/L"}, "info": {"postfix": "", "postfix_file": "", "kgcnn_version": "4.0.0"}}

0 comments on commit 7e57f4c

Please sign in to comment.