From cf94114439ae2dce983a00c345087a281da9d421 Mon Sep 17 00:00:00 2001 From: PatReis Date: Wed, 13 Sep 2023 17:57:01 +0200 Subject: [PATCH] continue keras core integration --- .../Schnet/Schnet_ESOLDataset_score.yaml | 200 +++++++++--------- .../ESOLDataset/Schnet/Schnet_hyper.json | 2 +- training_core/train_graph.py | 2 +- 3 files changed, 104 insertions(+), 100 deletions(-) diff --git a/training_core/results/ESOLDataset/Schnet/Schnet_ESOLDataset_score.yaml b/training_core/results/ESOLDataset/Schnet/Schnet_ESOLDataset_score.yaml index 0c497f7a..940ad61f 100644 --- a/training_core/results/ESOLDataset/Schnet/Schnet_ESOLDataset_score.yaml +++ b/training_core/results/ESOLDataset/Schnet/Schnet_ESOLDataset_score.yaml @@ -1,7 +1,11 @@ -OS: posix_linux -backend: tensorflow +OS: nt_win32 +backend: torch +cuda_available: 'True' data_unit: '' -date_time: '2023-09-08 16:32:16' +date_time: '2023-09-13 17:56:04' +device_id: '[0]' +device_memory: '[{''allocated'': 0.0, ''cached'': 0.1}]' +device_name: '[''NVIDIA GeForce GTX 1060 6GB'']' epochs: - 800 - 800 @@ -17,11 +21,11 @@ learning_rate: - 1.06999996205559e-05 - 1.06999996205559e-05 loss: -- 0.012356339953839779 -- 0.013351923786103725 -- 0.013529826886951923 -- 0.013422749936580658 -- 0.013661202043294907 +- 0.010011430829763412 +- 0.012302489019930363 +- 0.013542941771447659 +- 0.01063549891114235 +- 0.009813697077333927 max_learning_rate: - 0.0005000000237487257 - 0.0005000000237487257 @@ -29,41 +33,41 @@ max_learning_rate: - 0.0005000000237487257 - 0.0005000000237487257 max_loss: -- 0.7113643884658813 -- 0.7186248898506165 -- 0.7084435224533081 -- 0.7332432270050049 -- 0.706541895866394 +- 0.7063413262367249 +- 0.7609256505966187 +- 0.7159409523010254 +- 0.7039654850959778 +- 0.7093842625617981 max_scaled_mean_absolute_error: -- 1.4824234247207642 -- 1.5330605506896973 -- 1.5074206590652466 -- 1.5363695621490479 -- 1.503669261932373 +- 1.4892581701278687 +- 1.6069437265396118 +- 1.5262893438339233 +- 1.454487919807434 +- 1.4871799945831299 max_scaled_root_mean_squared_error: -- 1.8612427711486816 -- 1.9048391580581665 -- 1.8831799030303955 -- 2.121493101119995 -- 1.863604187965393 +- 1.8569190502166748 +- 1.9821991920471191 +- 2.953916311264038 +- 2.477454900741577 +- 1.8458967208862305 max_val_loss: -- 0.5424761772155762 -- 0.40656977891921997 -- 0.4068171977996826 -- 0.4622005820274353 -- 0.350281298160553 +- 0.3586343824863434 +- 0.48906615376472473 +- 0.46746182441711426 +- 0.4514211118221283 +- 0.4105031490325928 max_val_scaled_mean_absolute_error: -- 0.9954290390014648 -- 0.873090922832489 -- 0.7997042536735535 -- 0.8752737641334534 -- 0.834363579750061 +- 0.7646477818489075 +- 1.045943021774292 +- 1.0663926601409912 +- 0.880517303943634 +- 0.8691794276237488 max_val_scaled_root_mean_squared_error: -- 1.279673457145691 -- 1.077497959136963 -- 1.0372447967529297 -- 1.1205958127975464 -- 1.1094077825546265 +- 1.01933753490448 +- 1.292637825012207 +- 1.3106887340545654 +- 1.1120004653930664 +- 1.114079236984253 min_learning_rate: - 1.06999996205559e-05 - 1.06999996205559e-05 @@ -71,80 +75,80 @@ min_learning_rate: - 1.06999996205559e-05 - 1.06999996205559e-05 min_loss: -- 0.011985968798398972 -- 0.013351923786103725 -- 0.013529826886951923 -- 0.01335224136710167 -- 0.013466808013617992 +- 0.010011430829763412 +- 0.012302489019930363 +- 0.013271925039589405 +- 0.01063549891114235 +- 0.009813697077333927 min_scaled_mean_absolute_error: -- 0.025462720543146133 -- 0.02883092314004898 -- 0.02940377965569496 -- 0.028466414660215378 -- 0.02892298437654972 +- 0.02106151171028614 +- 0.026483522728085518 +- 0.028707485646009445 +- 0.02257038839161396 +- 0.021056318655610085 min_scaled_root_mean_squared_error: -- 0.07428765296936035 -- 0.08556132763624191 -- 0.08565159142017365 -- 0.09151832014322281 -- 0.08853696286678314 +- 0.07587568461894989 +- 0.09088648855686188 +- 0.09705951809883118 +- 0.0869671180844307 +- 0.07382410019636154 min_val_loss: -- 0.23827065527439117 -- 0.20599853992462158 -- 0.19142144918441772 -- 0.20846903324127197 -- 0.2028113603591919 +- 0.20171000063419342 +- 0.19535136222839355 +- 0.18761269748210907 +- 0.20568722486495972 +- 0.19795897603034973 min_val_scaled_mean_absolute_error: -- 0.4522663950920105 -- 0.4033041298389435 -- 0.40948936343193054 -- 0.48299241065979004 -- 0.4422135353088379 +- 0.4149721562862396 +- 0.3942352831363678 +- 0.41914138197898865 +- 0.48157763481140137 +- 0.43868088722229004 min_val_scaled_root_mean_squared_error: -- 0.7140107750892639 -- 0.5469481945037842 -- 0.5934625267982483 -- 0.6652877926826477 -- 0.6266308426856995 +- 0.6290794610977173 +- 0.5576820373535156 +- 0.618284285068512 +- 0.6716180443763733 +- 0.5965774655342102 model_class: make_model model_name: Schnet model_version: 2023.09.07 multi_target_indices: null number_histories: 5 scaled_mean_absolute_error: -- 0.02621050924062729 -- 0.02883092314004898 -- 0.02940377965569496 -- 0.028566351160407066 -- 0.02940848283469677 +- 0.02106151171028614 +- 0.026483522728085518 +- 0.029000014066696167 +- 0.02257038839161396 +- 0.021056318655610085 scaled_root_mean_squared_error: -- 0.07438243925571442 -- 0.08620187640190125 -- 0.08658499270677567 -- 0.09151832014322281 -- 0.08886461704969406 +- 0.07620621472597122 +- 0.091074638068676 +- 0.09821934998035431 +- 0.08720821142196655 +- 0.07382410019636154 seed: 42 time_list: -- '0:04:27.346256' -- '0:04:26.197451' -- '0:04:30.127993' -- '0:04:31.160348' -- '0:04:32.070582' +- '0:20:17.603629' +- '0:19:09.172761' +- '0:21:18.667718' +- '0:20:18.305153' +- '0:19:49.705409' val_loss: -- 0.2750350832939148 -- 0.22051694989204407 -- 0.23590177297592163 -- 0.21229249238967896 -- 0.2264338731765747 +- 0.2412828654050827 +- 0.2220507264137268 +- 0.23384696245193481 +- 0.2138018012046814 +- 0.2256581038236618 val_scaled_mean_absolute_error: -- 0.48352858424186707 -- 0.4216223955154419 -- 0.4384129047393799 -- 0.4986894130706787 -- 0.4725487530231476 +- 0.4590207040309906 +- 0.4284597635269165 +- 0.46461614966392517 +- 0.49751242995262146 +- 0.4585217833518982 val_scaled_root_mean_squared_error: -- 0.7575325965881348 -- 0.583825409412384 -- 0.607445240020752 -- 0.7178400754928589 -- 0.6797709465026855 +- 0.6878407001495361 +- 0.5926348567008972 +- 0.680878758430481 +- 0.6880843043327332 +- 0.6271232962608337 diff --git a/training_core/results/ESOLDataset/Schnet/Schnet_hyper.json b/training_core/results/ESOLDataset/Schnet/Schnet_hyper.json index f93be33e..b33d262e 100644 --- a/training_core/results/ESOLDataset/Schnet/Schnet_hyper.json +++ b/training_core/results/ESOLDataset/Schnet/Schnet_hyper.json @@ -1 +1 @@ -{"model": {"class_name": "make_model", "module_name": "kgcnn.literature_core.Schnet", "config": {"name": "Schnet", "inputs": [{"shape": [null], "name": "node_number", "dtype": "int32"}, {"shape": [null, 3], "name": "node_coordinates", "dtype": "float32"}, {"shape": [null, 2], "name": "range_indices", "dtype": "int64"}, {"shape": [], "name": "total_nodes", "dtype": "int64"}, {"shape": [], "name": "total_ranges", "dtype": "int64"}], "input_node_embedding": {"input_dim": 95, "output_dim": 64}, "output_embedding": "graph", "output_mlp": {"use_bias": [true, true], "units": [64, 1], "activation": ["kgcnn>shifted_softplus", "linear"]}, "last_mlp": {"use_bias": [true, true], "units": [128, 64], "activation": ["kgcnn>shifted_softplus", "kgcnn>shifted_softplus"]}, "interaction_args": {"units": 128, "use_bias": true, "activation": "kgcnn>shifted_softplus", "cfconv_pool": "scatter_sum"}, "node_pooling_args": {"pooling_method": "scatter_sum"}, "depth": 4, "gauss_args": {"bins": 20, "distance": 4, "offset": 0.0, "sigma": 0.4}, "verbose": 10}}, "training": {"cross_validation": {"class_name": "KFold", "config": {"n_splits": 5, "random_state": 42, "shuffle": true}}, "scaler": {"class_name": "StandardScaler", "config": {"with_std": true, "with_mean": true, "copy": true}}, "fit": {"batch_size": 32, "epochs": 800, "validation_freq": 10, "verbose": 2, "callbacks": [{"class_name": "kgcnn>LinearLearningRateScheduler", "config": {"learning_rate_start": 0.0005, "learning_rate_stop": 1e-05, "epo_min": 100, "epo": 800, "verbose": 0}}]}, "compile": {"optimizer": {"class_name": "Adam", "config": {"learning_rate": 0.0005}}, "loss": "mean_absolute_error"}}, "data": {}, "dataset": {"class_name": "ESOLDataset", "module_name": "kgcnn.data.datasets.ESOLDataset", "config": {}, "methods": [{"set_attributes": {}}, {"map_list": {"method": "set_range", "max_distance": 4, "max_neighbours": 10000}}, {"map_list": {"method": "count_nodes_and_edges", "total_edges": "total_ranges", "count_edges": "range_indices"}}]}, "info": {"postfix": "", "postfix_file": "", "kgcnn_version": "4.0.0"}} \ No newline at end of file +{"model": {"class_name": "make_model", "module_name": "kgcnn.literature_core.Schnet", "config": {"name": "Schnet", "inputs": [{"shape": [null], "name": "node_number", "dtype": "int32"}, {"shape": [null, 3], "name": "node_coordinates", "dtype": "float32"}, {"shape": [null, 2], "name": "range_indices", "dtype": "int64"}, {"shape": [], "name": "total_nodes", "dtype": "int64"}, {"shape": [], "name": "total_ranges", "dtype": "int64"}], "cast_disjoint_kwargs": {"padded_disjoint": false}, "input_node_embedding": {"input_dim": 95, "output_dim": 64}, "output_embedding": "graph", "output_mlp": {"use_bias": [true, true], "units": [64, 1], "activation": ["kgcnn>shifted_softplus", "linear"]}, "last_mlp": {"use_bias": [true, true], "units": [128, 64], "activation": ["kgcnn>shifted_softplus", "kgcnn>shifted_softplus"]}, "interaction_args": {"units": 128, "use_bias": true, "activation": "kgcnn>shifted_softplus", "cfconv_pool": "scatter_sum"}, "node_pooling_args": {"pooling_method": "scatter_sum"}, "depth": 4, "gauss_args": {"bins": 20, "distance": 4, "offset": 0.0, "sigma": 0.4}, "verbose": 10}}, "training": {"cross_validation": {"class_name": "KFold", "config": {"n_splits": 5, "random_state": 42, "shuffle": true}}, "scaler": {"class_name": "StandardScaler", "config": {"with_std": true, "with_mean": true, "copy": true}}, "fit": {"batch_size": 32, "epochs": 800, "validation_freq": 10, "verbose": 2, "callbacks": [{"class_name": "kgcnn>LinearLearningRateScheduler", "config": {"learning_rate_start": 0.0005, "learning_rate_stop": 1e-05, "epo_min": 100, "epo": 800, "verbose": 0}}]}, "compile": {"optimizer": {"class_name": "Adam", "config": {"learning_rate": 0.0005}}, "loss": "mean_absolute_error"}}, "data": {}, "dataset": {"class_name": "ESOLDataset", "module_name": "kgcnn.data.datasets.ESOLDataset", "config": {}, "methods": [{"set_attributes": {}}, {"map_list": {"method": "set_range", "max_distance": 4, "max_neighbours": 10000}}, {"map_list": {"method": "count_nodes_and_edges", "total_edges": "total_ranges", "count_edges": "range_indices"}}]}, "info": {"postfix": "", "postfix_file": "", "kgcnn_version": "4.0.0"}} \ No newline at end of file diff --git a/training_core/train_graph.py b/training_core/train_graph.py index 169fb6c9..108a4825 100644 --- a/training_core/train_graph.py +++ b/training_core/train_graph.py @@ -21,7 +21,7 @@ parser = argparse.ArgumentParser(description='Train a GNN on a Molecule dataset.') parser.add_argument("--hyper", required=False, help="Filepath to hyperparameter config file (.py or .json).", default="hyper/hyper_esol.py") -parser.add_argument("--category", required=False, help="Graph model to train.", default="GCN") +parser.add_argument("--category", required=False, help="Graph model to train.", default="Schnet") parser.add_argument("--model", required=False, help="Graph model to train.", default=None) parser.add_argument("--dataset", required=False, help="Name of the dataset.", default=None) parser.add_argument("--make", required=False, help="Name of the class for model.", default=None)