Skip to content

Commit

Permalink
continue keras core integration
Browse files Browse the repository at this point in the history
  • Loading branch information
PatReis committed Sep 5, 2023
1 parent 6cb727d commit 1bf70ae
Show file tree
Hide file tree
Showing 4 changed files with 155 additions and 97 deletions.
12 changes: 12 additions & 0 deletions kgcnn/backend/tensorflow.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
import tensorflow as tf


def scatter_min(indices, values, shape):
return tf.scatter_min(indices, values, shape)

def scatter_max(indices, values, shape):
return tf.scatter_nd(indices, values, shape)

def scatter_mean(indices, values, shape):

return tf.scatter_nd(indices, values, shape)
46 changes: 46 additions & 0 deletions kgcnn/ops_core/scatter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
from kgcnn import backend
from keras_core.backend import KerasTensor
from keras_core.backend import any_symbolic_tensors
from keras_core.ops.operation import Operation


class ScatterMax(Operation):
def call(self, indices, values, shape):
return backend.scatter_max(indices, values, shape)

def compute_output_spec(self, indices, values, shape):
return KerasTensor(shape, dtype=values.dtype)


def scatter_max(indices, values, shape):
if any_symbolic_tensors((indices, values, shape)):
return ScatterMax().symbolic_call(indices, values, shape)
return backend.scatter_max(indices, values, shape)


class ScatterMin(Operation):
def call(self, indices, values, shape):
return backend.scatter_min(indices, values, shape)

def compute_output_spec(self, indices, values, shape):
return KerasTensor(shape, dtype=values.dtype)


def scatter_min(indices, values, shape):
if any_symbolic_tensors((indices, values, shape)):
return ScatterMin().symbolic_call(indices, values, shape)
return backend.scatter_min(indices, values, shape)


class ScatterMean(Operation):
def call(self, indices, values, shape):
return backend.scatter_mean(indices, values, shape)

def compute_output_spec(self, indices, values, shape):
return KerasTensor(shape, dtype=values.dtype)


def scatter_mean(indices, values, shape):
if any_symbolic_tensors((indices, values, shape)):
return ScatterMean().symbolic_call(indices, values, shape)
return backend.scatter_mean(indices, values, shape)
192 changes: 96 additions & 96 deletions training_core/results/ESOLDataset/GCN/GCN_ESOLDataset_score.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
data_unit: mol/L
date_time: '2023-09-04 13:32:26'
date_time: '2023-09-05 13:02:37'
epochs:
- 800
- 800
Expand All @@ -15,134 +15,134 @@ learning_rate:
- 5.1727271056734025e-05
- 5.1727271056734025e-05
loss:
- 0.02065335027873516
- 0.03046582266688347
- 0.026986178010702133
- 0.035748980939388275
- 0.025066649541258812
- 0.030111245810985565
- 0.029905715957283974
- 0.02666669711470604
- 0.03320549055933952
- 0.027655862271785736
max_learning_rate:
- 0.0010000000474974513
- 0.0010000000474974513
- 0.0010000000474974513
- 0.0010000000474974513
- 0.0010000000474974513
max_loss:
- 0.7785677909851074
- 0.7740250825881958
- 0.7712702751159668
- 0.7745649814605713
- 0.785307765007019
- 0.7727673053741455
- 0.766106128692627
- 0.7710932493209839
- 0.7851217985153198
- 0.7929884791374207
max_scaled_mean_absolute_error:
- 1.630502700805664
- 1.6439743041992188
- 1.6447334289550781
- 1.6330018043518066
- 1.6777045726776123
- 1.6122925281524658
- 1.6240359544754028
- 1.64657461643219
- 1.6524299383163452
- 1.694350242614746
max_scaled_root_mean_squared_error:
- 2.051668882369995
- 2.079864501953125
- 2.104031801223755
- 2.0745174884796143
- 2.1795969009399414
- 2.045316457748413
- 2.058136224746704
- 2.1218948364257812
- 2.086286783218384
- 2.2071051597595215
max_val_loss:
- 0.38604891300201416
- 0.4033467471599579
- 0.43952175974845886
- 0.4190915822982788
- 0.4452379643917084
- 0.4188742935657501
- 0.3975107669830322
- 0.36188578605651855
- 0.4290943443775177
- 0.3819858431816101
max_val_scaled_mean_absolute_error:
- 0.7995112538337708
- 0.9048331379890442
- 0.7838926911354065
- 0.7904984951019287
- 0.8072317838668823
- 0.8083961009979248
- 0.8527750968933105
- 0.680066704750061
- 0.8364616632461548
- 0.7286641001701355
max_val_scaled_root_mean_squared_error:
- 1.0650511980056763
- 1.1253228187561035
- 1.0860916376113892
- 1.0380667448043823
- 1.0352442264556885
- 1.0951517820358276
- 1.0844537019729614
- 0.9051862955093384
- 1.0477832555770874
- 0.9376411437988281
min_learning_rate:
- 5.1727271056734025e-05
- 5.1727271056734025e-05
- 5.1727271056734025e-05
- 5.1727271056734025e-05
- 5.1727271056734025e-05
min_loss:
- 0.020218532532453537
- 0.025095166638493538
- 0.026986178010702133
- 0.027601951733231544
- 0.025066649541258812
- 0.022370334714651108
- 0.02377462387084961
- 0.026008524000644684
- 0.02583014965057373
- 0.027655862271785736
min_scaled_mean_absolute_error:
- 0.04289504885673523
- 0.05423098802566528
- 0.05875982344150543
- 0.058775320649147034
- 0.0540020689368248
- 0.04759758338332176
- 0.05133747309446335
- 0.05656632035970688
- 0.05493299663066864
- 0.05945947393774986
min_scaled_root_mean_squared_error:
- 0.12030639499425888
- 0.15153780579566956
- 0.1605674773454666
- 0.16291119158267975
- 0.15338918566703796
- 0.12559226155281067
- 0.14683835208415985
- 0.16013726592063904
- 0.15730111300945282
- 0.163068026304245
min_val_loss:
- 0.22336530685424805
- 0.2090274691581726
- 0.18527281284332275
- 0.20915909111499786
- 0.27025991678237915
- 0.22113201022148132
- 0.19192928075790405
- 0.20304016768932343
- 0.19797971844673157
- 0.26365381479263306
min_val_scaled_mean_absolute_error:
- 0.46149882674217224
- 0.4658408463001251
- 0.4380309283733368
- 0.4883240759372711
- 0.456197589635849
- 0.45169585943222046
- 0.4429832398891449
- 0.4455697536468506
- 0.4632081687450409
- 0.46592286229133606
min_val_scaled_root_mean_squared_error:
- 0.6655771136283875
- 0.6615913510322571
- 0.6196939945220947
- 0.6811450123786926
- 0.6317007541656494
- 0.6493441462516785
- 0.6086452007293701
- 0.6509482264518738
- 0.6738059520721436
- 0.6359040141105652
model_class: make_model
model_name: GCN
model_version: 2023.09.30
multi_target_indices: null
number_histories: 5
scaled_mean_absolute_error:
- 0.04376649111509323
- 0.06548744440078735
- 0.05875982344150543
- 0.07592686265707016
- 0.0540020689368248
- 0.06371089071035385
- 0.06444712728261948
- 0.0579991340637207
- 0.07066909223794937
- 0.05945947393774986
scaled_root_mean_squared_error:
- 0.12186748534440994
- 0.15425890684127808
- 0.1605674773454666
- 0.16617444157600403
- 0.15425477921962738
- 0.13027626276016235
- 0.15026293694972992
- 0.16108879446983337
- 0.1621994823217392
- 0.163068026304245
seed: 42
time_list:
- '0:02:41.250184'
- '0:02:48.877327'
- '0:02:45.440704'
- '0:02:51.494117'
- '0:02:54.385403'
- '0:06:16.883915'
- '0:06:34.529023'
- '0:07:15.199847'
- '0:10:23.172025'
- '0:10:06.016519'
val_loss:
- 0.22336530685424805
- 0.2319173663854599
- 0.22124597430229187
- 0.21199733018875122
- 0.30064311623573303
- 0.23328134417533875
- 0.21415188908576965
- 0.21731525659561157
- 0.20350980758666992
- 0.2738124430179596
val_scaled_mean_absolute_error:
- 0.463508278131485
- 0.5030187964439392
- 0.4640654921531677
- 0.4997897446155548
- 0.48311617970466614
- 0.46819519996643066
- 0.4745296835899353
- 0.4473111033439636
- 0.47483858466148376
- 0.5058296918869019
val_scaled_root_mean_squared_error:
- 0.666396975517273
- 0.7374665141105652
- 0.674218475818634
- 0.694538414478302
- 0.7118412852287292
- 0.6607564687728882
- 0.6795703172683716
- 0.674023449420929
- 0.6981791853904724
- 0.7606579065322876
2 changes: 1 addition & 1 deletion training_core/results/ESOLDataset/GCN/GCN_hyper.json
Original file line number Diff line number Diff line change
@@ -1 +1 @@
{"model": {"class_name": "make_model", "module_name": "kgcnn.literature_core.GCN", "config": {"name": "GCN", "inputs": [{"shape": [null, 41], "name": "node_attributes", "dtype": "float32"}, {"shape": [null, 1], "name": "edge_weights", "dtype": "float32"}, {"shape": [null, 2], "name": "edge_indices", "dtype": "int64"}, {"shape": [], "name": "graph_size", "dtype": "int64"}, {"shape": [], "name": "edge_count", "dtype": "int64"}], "input_node_embedding": {"input_dim": 95, "output_dim": 64}, "input_edge_embedding": {"input_dim": 25, "output_dim": 1}, "gcn_args": {"units": 140, "use_bias": true, "activation": "relu"}, "depth": 5, "verbose": 10, "output_embedding": "graph", "output_mlp": {"use_bias": [true, true, false], "units": [140, 70, 1], "activation": ["relu", "relu", "linear"]}}}, "training": {"fit": {"batch_size": 32, "epochs": 800, "validation_freq": 10, "verbose": 2, "callbacks": [{"class_name": "kgcnn>LinearLearningRateScheduler", "config": {"learning_rate_start": 0.001, "learning_rate_stop": 5e-05, "epo_min": 250, "epo": 800, "verbose": 0}}]}, "compile": {"optimizer": {"class_name": "Adam", "config": {"learning_rate": 0.001}}, "loss": "mean_absolute_error", "jit_compile": false}, "cross_validation": {"class_name": "KFold", "config": {"n_splits": 5, "random_state": 42, "shuffle": true}}, "scaler": {"class_name": "StandardScaler", "config": {"with_std": true, "with_mean": true, "copy": true}}}, "dataset": {"class_name": "ESOLDataset", "module_name": "kgcnn.data.datasets.ESOLDataset", "config": {}, "methods": [{"set_attributes": {}}, {"map_list": {"method": "normalize_edge_weights_sym"}}]}, "data": {"data_unit": "mol/L"}, "info": {"postfix": "", "postfix_file": "", "kgcnn_version": "2.0.3"}}
{"model": {"class_name": "make_model", "module_name": "kgcnn.literature_core.GCN", "config": {"name": "GCN", "inputs": [{"shape": [null, 41], "name": "node_attributes", "dtype": "float32"}, {"shape": [null, 1], "name": "edge_weights", "dtype": "float32"}, {"shape": [null, 2], "name": "edge_indices", "dtype": "int64"}, {"shape": [], "name": "graph_size", "dtype": "int64"}, {"shape": [], "name": "edge_count", "dtype": "int64"}], "input_node_embedding": {"input_dim": 95, "output_dim": 64}, "input_edge_embedding": {"input_dim": 25, "output_dim": 1}, "gcn_args": {"units": 140, "use_bias": true, "activation": "relu"}, "depth": 5, "verbose": 10, "output_embedding": "graph", "output_mlp": {"use_bias": [true, true, false], "units": [140, 70, 1], "activation": ["relu", "relu", "linear"]}}}, "training": {"fit": {"batch_size": 32, "epochs": 800, "validation_freq": 10, "verbose": 2, "callbacks": [{"class_name": "kgcnn>LinearLearningRateScheduler", "config": {"learning_rate_start": 0.001, "learning_rate_stop": 5e-05, "epo_min": 250, "epo": 800, "verbose": 0}}]}, "compile": {"optimizer": {"class_name": "Adam", "config": {"learning_rate": 0.001}}, "loss": "mean_absolute_error", "jit_compile": true}, "cross_validation": {"class_name": "KFold", "config": {"n_splits": 5, "random_state": 42, "shuffle": true}}, "scaler": {"class_name": "StandardScaler", "config": {"with_std": true, "with_mean": true, "copy": true}}}, "dataset": {"class_name": "ESOLDataset", "module_name": "kgcnn.data.datasets.ESOLDataset", "config": {}, "methods": [{"set_attributes": {}}, {"map_list": {"method": "normalize_edge_weights_sym"}}]}, "data": {"data_unit": "mol/L"}, "info": {"postfix": "", "postfix_file": "", "kgcnn_version": "2.0.3"}}

0 comments on commit 1bf70ae

Please sign in to comment.