diff --git a/kgcnn/backend/_tensorflow.py b/kgcnn/backend/_tensorflow.py index 235a9f62..ce618e17 100644 --- a/kgcnn/backend/_tensorflow.py +++ b/kgcnn/backend/_tensorflow.py @@ -21,7 +21,7 @@ def scatter_reduce_max(indices, values, shape): def scatter_reduce_mean(indices, values, shape): indices = tf.expand_dims(indices, axis=-1) counts = tf.scatter_nd(indices, tf.ones_like(values), shape) - return tf.scatter_nd(indices, values, shape)/counts + return tf.math.divide_no_nan(tf.scatter_nd(indices, values, shape), counts) def scatter_reduce_softmax(indices, values, shape, normalize: bool = False): diff --git a/kgcnn/literature/GIN/_make.py b/kgcnn/literature/GIN/_make.py index fe457d96..b89bfa59 100644 --- a/kgcnn/literature/GIN/_make.py +++ b/kgcnn/literature/GIN/_make.py @@ -108,7 +108,7 @@ def make_model(inputs: list = None, # Wrapping disjoint model. out = model_disjoint( - inputs, + [n, disjoint_indices, batch_id_node, count_nodes], use_node_embedding=len(inputs[0]['shape']) < 2, input_node_embedding=input_node_embedding, depth=depth, gin_args=gin_args, gin_mlp=gin_mlp, last_mlp=last_mlp, dropout=dropout, output_embedding=output_embedding, output_mlp=output_mlp diff --git a/kgcnn/literature/GraphSAGE/_make.py b/kgcnn/literature/GraphSAGE/_make.py index da811e4b..2ce140fa 100644 --- a/kgcnn/literature/GraphSAGE/_make.py +++ b/kgcnn/literature/GraphSAGE/_make.py @@ -30,14 +30,14 @@ model_default = { 'name': "GraphSAGE", 'inputs': [ - {"shape": (None, ), "name": "node_attributes", "dtype": "float32"}, - {"shape": (None, ), "name": "edge_attributes", "dtype": "float32"}, + {"shape": (None,), "name": "node_attributes", "dtype": "float32"}, + {"shape": (None,), "name": "edge_attributes", "dtype": "float32"}, {"shape": (None, 2), "name": "edge_indices", "dtype": "int64"}, {"shape": (), "name": "total_nodes", "dtype": "int64"}, {"shape": (), "name": "total_edges", "dtype": "int64"} ], "cast_disjoint_kwargs": {}, - "input_node_embedding": {"input_dim": 95, "output_dim": 64}, + "input_node_embedding": {"input_dim": 95, "output_dim": 64}, "input_edge_embedding": {"input_dim": 5, "output_dim": 64}, 'node_mlp_args': {"units": [100, 50], "use_bias": True, "activation": ['relu', "linear"]}, 'edge_mlp_args': {"units": [100, 50], "use_bias": True, "activation": ['relu', "linear"]}, @@ -121,48 +121,25 @@ def make_model(inputs: list = None, **cast_disjoint_kwargs)([batched_nodes, batched_indices, total_nodes, total_edges]) ed, _, _, _ = CastBatchedAttributesToDisjoint(**cast_disjoint_kwargs)([batched_edges, total_edges]) - # Embedding, if no feature dimension - if len(inputs[0]['shape']) < 2: - n = Embedding(**input_node_embedding)(n) - if len(inputs[1]['shape']) < 2: - ed = Embedding(**input_edge_embedding)(ed) - - for i in range(0, depth): - - eu = GatherNodesOutgoing(**gather_args)([n, disjoint_indices]) - if use_edge_features: - eu = Concatenate(**concat_args)([eu, ed]) - - eu = GraphMLP(**edge_mlp_args)([eu, batch_id_edge, count_edges]) - - # Pool message - if pooling_args['pooling_method'] in ["LSTM", "lstm"]: - nu = AggregateLocalEdgesLSTM(**pooling_args)([n, eu, disjoint_indices]) - else: - nu = AggregateLocalEdges(**pooling_args)([n, eu, disjoint_indices]) # Summing for each node connection - - nu = Concatenate(**concat_args)([n, nu]) # Concatenate node features with new edge updates - - n = GraphMLP(**node_mlp_args)([nu, batch_id_node, count_nodes]) - - n = GraphLayerNormalization()([n, batch_id_node, count_nodes]) + out = model_disjoint( + [n, ed, disjoint_indices, batch_id_node, batch_id_edge, count_nodes, count_edges], + use_node_embedding=len(inputs[0]['shape']) < 2, use_edge_embedding=len(inputs[1]['shape']) < 2, + input_node_embedding=input_node_embedding, input_edge_embedding=input_edge_embedding, + node_mlp_args=node_mlp_args, edge_mlp_args=edge_mlp_args, pooling_args=pooling_args, + pooling_nodes_args=pooling_nodes_args, gather_args=gather_args, concat_args=concat_args, + use_edge_features=use_edge_features, depth=depth, output_embedding=output_embedding, + output_mlp=output_mlp, + ) # Regression layer on output if output_embedding == 'graph': - out = PoolingNodes(**pooling_nodes_args)([count_nodes, n, batch_id_node]) - out = MLP(**output_mlp)(out) out = CastDisjointToGraphState(**cast_disjoint_kwargs)(out) - elif output_embedding == 'node': - out = GraphMLP(**output_mlp)([n, batch_id_node, count_nodes]) if output_to_tensor: out = CastDisjointToBatchedAttributes(**cast_disjoint_kwargs)([batched_nodes, out, batch_id_node, node_id]) else: out = CastDisjointToGraphState(**cast_disjoint_kwargs)(out) - else: - raise ValueError("Unsupported output embedding for `GraphSAGE`") - if output_scaling is not None: scaler = get_scaler(output_scaling["name"])(**output_scaling) out = scaler(out) @@ -173,6 +150,7 @@ def make_model(inputs: list = None, if output_scaling is not None: def set_scale(*args, **kwargs): scaler.set_scale(*args, **kwargs) + setattr(model, "set_scale", set_scale) return model @@ -194,7 +172,7 @@ def model_disjoint( depth: int = None, output_embedding: str = None, output_mlp: dict = None, - ): +): n, ed, disjoint_indices, batch_id_node, batch_id_edge, count_nodes, count_edges = inputs # Embedding, if no feature dimension @@ -231,4 +209,4 @@ def model_disjoint( out = GraphMLP(**output_mlp)([n, batch_id_node, count_nodes]) else: raise ValueError("Unsupported output embedding for `GraphSAGE`") - return out \ No newline at end of file + return out diff --git a/training/hyper/hyper_clintox.py b/training/hyper/hyper_clintox.py index 86f46015..df360e91 100644 --- a/training/hyper/hyper_clintox.py +++ b/training/hyper/hyper_clintox.py @@ -41,7 +41,6 @@ "loss": "binary_crossentropy", "metrics": ["binary_accuracy", {"class_name": "AUC", "config": {"name": "auc"}}] }, - "multi_target_indices": None }, "dataset": { "class_name": "ClinToxDataset", @@ -393,4 +392,78 @@ "kgcnn_version": "4.0.0" } }, + "DMPNN": { + "model": { + "class_name": "make_model", + "module_name": "kgcnn.literature.DMPNN", + "config": { + "name": "DMPNN", + "inputs": [ + {"shape": (None, 41), "name": "node_attributes", "dtype": "float32"}, + {"shape": (None, 11), "name": "edge_attributes", "dtype": "float32"}, + {"shape": (None, 2), "name": "edge_indices", "dtype": "int64"}, + {"shape": (None, 1), "name": "edge_indices_reverse", "dtype": "int64"}, + {"shape": (), "name": "total_nodes", "dtype": "int64"}, + {"shape": (), "name": "total_edges", "dtype": "int64"} + ], + "cast_disjoint_kwargs": {}, + "input_node_embedding": {"input_dim": 95, "output_dim": 64}, + "input_edge_embedding": {"input_dim": 5, "output_dim": 64}, + "input_graph_embedding": {"input_dim": 100, "output_dim": 64}, + "pooling_args": {"pooling_method": "scatter_sum"}, + "edge_initialize": {"units": 128, "use_bias": True, "activation": "relu"}, + "edge_dense": {"units": 128, "use_bias": True, "activation": "linear"}, + "edge_activation": {"activation": "relu"}, + "node_dense": {"units": 128, "use_bias": True, "activation": "relu"}, + "verbose": 10, "depth": 5, + "dropout": {"rate": 0.1}, + "output_embedding": "graph", + "output_mlp": { + "use_bias": [True, True, False], "units": [64, 32, 1], + "activation": ["relu", "relu", "sigmoid"] + } + } + }, + "training": { + "fit": {"batch_size": 32, "epochs": 50, "validation_freq": 1, "verbose": 2, "callbacks": []}, + "compile": { + "optimizer": { + "class_name": "Adam", + "config": { + "learning_rate": + {"module": "keras_core.optimizers.schedules", + "class_name": "ExponentialDecay", + "config": {"initial_learning_rate": 0.001, + "decay_steps": 1600, + "decay_rate": 0.5, "staircase": False}} + } + }, + # "loss": "kgcnn>BinaryCrossentropyNoNaN", + # "metrics": ["kgcnn>BinaryAccuracyNoNaN", + # {"class_name": "kgcnn>AUCNoNaN", "config": {"multi_label": True, "num_labels": 12}}], + # "metrics": ["kgcnn>BinaryAccuracyNoNaN", "kgcnn>AUCNoNaN"], + "loss": "binary_crossentropy", + "metrics": ["binary_accuracy", {"class_name": "AUC", "config": {"name": "auc"}}] + } + }, + "dataset": { + "class_name": "ClinToxDataset", + "module_name": "kgcnn.data.datasets.ClinToxDataset", + "config": {}, + "methods": [ + {"set_attributes": {}}, + {"set_train_test_indices_k_fold": {"n_splits": 5, "random_state": 42, "shuffle": True}}, + {"map_list": {"method": "set_edge_indices_reverse"}}, + {"map_list": {"method": "count_nodes_and_edges"}}, + ] + }, + "data": { + "data_unit": "mol/L" + }, + "info": { + "postfix": "", + "postfix_file": "", + "kgcnn_version": "4.0.0" + } + }, } \ No newline at end of file diff --git a/training/hyper/hyper_cora_lu.py b/training/hyper/hyper_cora_lu.py index 4140e5fb..72cb648f 100644 --- a/training/hyper/hyper_cora_lu.py +++ b/training/hyper/hyper_cora_lu.py @@ -328,4 +328,78 @@ "kgcnn_version": "4.0.0" } }, + "DMPNN": { + "model": { + "class_name": "make_model", + "module_name": "kgcnn.literature.DMPNN", + "config": { + "name": "DMPNN", + "inputs": [ + {"shape": [None, 1433], "name": "node_attributes", "dtype": "float32"}, + {"shape": [None, 1], "name": "edge_weights", "dtype": "float32"}, + {"shape": [None, 2], "name": "edge_indices", "dtype": "int64"}, + {"shape": (None, 1), "name": "edge_indices_reverse", "dtype": "int64"}, + {"shape": (), "name": "total_nodes", "dtype": "int64"}, + {"shape": (), "name": "total_edges", "dtype": "int64"} + ], + "cast_disjoint_kwargs": {}, + "input_node_embedding": {"input_dim": 95, "output_dim": 64}, + "input_edge_embedding": {"input_dim": 5, "output_dim": 64}, + "input_graph_embedding": {"input_dim": 100, "output_dim": 64}, + "pooling_args": {"pooling_method": "scatter_sum"}, + "edge_initialize": {"units": 128, "use_bias": True, "activation": "relu"}, + "edge_dense": {"units": 128, "use_bias": True, "activation": "linear"}, + "edge_activation": {"activation": "relu"}, + "node_dense": {"units": 128, "use_bias": True, "activation": "relu"}, + "verbose": 10, "depth": 5, + "dropout": {"rate": 0.1}, + "output_embedding": "node", + "output_mlp": { + "use_bias": [True, True, False], "units": [64, 32, 7], + "activation": ["relu", "relu", "softmax"] + } + } + }, + "training": { + "cross_validation": {"class_name": "KFold", + "config": {"n_splits": 5, "random_state": 42, "shuffle": True}}, + "multi_target_indices": None, + "fit": {"batch_size": 32, "epochs": 300, "validation_freq": 1, "verbose": 2, "callbacks": []}, + "compile": { + "optimizer": { + "class_name": "Adam", + "config": { + "learning_rate": + {"module": "keras_core.optimizers.schedules", + "class_name": "ExponentialDecay", + "config": {"initial_learning_rate": 0.001, + "decay_steps": 1600, + "decay_rate": 0.5, "staircase": False}} + } + }, + "loss": "categorical_crossentropy", + "weighted_metrics": ["categorical_accuracy", {"class_name": "AUC", "config": {"name": "auc"}}] + }, + }, + "dataset": { + "class_name": "CoraLuDataset", + "module_name": "kgcnn.data.datasets.CoraLuDataset", + "config": {}, + "methods": [ + {"map_list": {"method": "make_undirected_edges"}}, + {"map_list": {"method": "add_edge_self_loops"}}, + {"map_list": {"method": "normalize_edge_weights_sym"}}, + {"map_list": {"method": "set_edge_indices_reverse"}}, + {"map_list": {"method": "count_nodes_and_edges"}}, + ] + }, + "data": { + "data_unit": "" + }, + "info": { + "postfix": "", + "postfix_file": "", + "kgcnn_version": "4.0.0" + } + }, } \ No newline at end of file diff --git a/training/results/ClinToxDataset/DMPNN/DMPNN_ClinToxDataset_score.yaml b/training/results/ClinToxDataset/DMPNN/DMPNN_ClinToxDataset_score.yaml new file mode 100644 index 00000000..5621566b --- /dev/null +++ b/training/results/ClinToxDataset/DMPNN/DMPNN_ClinToxDataset_score.yaml @@ -0,0 +1,136 @@ +OS: nt_win32 +auc: +- 0.9929494261741638 +- 0.9952705502510071 +- 0.9950457215309143 +- 0.9944009184837341 +- 0.9972527623176575 +backend: tensorflow +binary_accuracy: +- 0.974662184715271 +- 0.9805907011032104 +- 0.9847972989082336 +- 0.9830938577651978 +- 0.9847972989082336 +cuda_available: 'False' +data_unit: mol/L +date_time: '2023-09-28 18:29:42' +device_id: '[LogicalDevice(name=''/device:CPU:0'', device_type=''CPU'')]' +device_memory: '[]' +device_name: '[{}]' +epochs: +- 50 +- 50 +- 50 +- 50 +- 50 +execute_folds: null +kgcnn_version: 4.0.0 +loss: +- 0.05233962461352348 +- 0.04196326807141304 +- 0.04036809504032135 +- 0.04770557954907417 +- 0.033450283110141754 +max_auc: +- 0.9944294691085815 +- 0.9957322478294373 +- 0.9951406717300415 +- 0.9951319098472595 +- 0.9973369240760803 +max_binary_accuracy: +- 0.9797297120094299 +- 0.9831223487854004 +- 0.9847972989082336 +- 0.9830938577651978 +- 0.9881756901741028 +max_loss: +- 0.7650876045227051 +- 1.0989446640014648 +- 0.876670241355896 +- 1.5429741144180298 +- 1.1330821514129639 +max_val_auc: +- 0.8952381610870361 +- 0.8650503754615784 +- 0.9409422278404236 +- 0.9057533144950867 +- 0.857114315032959 +max_val_binary_accuracy: +- 0.9560810923576355 +- 0.9559321999549866 +- 0.9527027010917664 +- 0.9730639457702637 +- 0.9594594836235046 +max_val_loss: +- 0.43353766202926636 +- 0.7357405424118042 +- 0.4855177402496338 +- 0.4508592188358307 +- 0.4019092917442322 +min_auc: +- 0.4021470546722412 +- 0.4715109169483185 +- 0.48213809728622437 +- 0.46312111616134644 +- 0.44651639461517334 +min_binary_accuracy: +- 0.8826013803482056 +- 0.8742616176605225 +- 0.8944256901741028 +- 0.8393913507461548 +- 0.9130067825317383 +min_loss: +- 0.04891818389296532 +- 0.04087803512811661 +- 0.04036809504032135 +- 0.04770557954907417 +- 0.033450283110141754 +min_val_auc: +- 0.41168832778930664 +- 0.3587069809436798 +- 0.4413570463657379 +- 0.44262558221817017 +- 0.4436451196670532 +min_val_binary_accuracy: +- 0.9020270109176636 +- 0.9288135766983032 +- 0.8851351141929626 +- 0.7643097639083862 +- 0.9391891956329346 +min_val_loss: +- 0.19326889514923096 +- 0.2550320029258728 +- 0.13597355782985687 +- 0.10933977365493774 +- 0.1705658733844757 +model_class: make_model +model_name: DMPNN +model_version: '2023-09-26' +multi_target_indices: null +number_histories: 5 +seed: 42 +time_list: +- '0:01:10.809307' +- '0:01:10.385692' +- '0:01:19.683660' +- '0:01:33.278228' +- '0:01:47.670030' +val_auc: +- 0.8379220366477966 +- 0.7629475593566895 +- 0.9326476454734802 +- 0.8089637756347656 +- 0.8059552311897278 +val_binary_accuracy: +- 0.9493243098258972 +- 0.9491525292396545 +- 0.9222972989082336 +- 0.9629629850387573 +- 0.9560810923576355 +val_loss: +- 0.2909963130950928 +- 0.5557326078414917 +- 0.22096827626228333 +- 0.13852910697460175 +- 0.4019092917442322 diff --git a/training/results/ClinToxDataset/DMPNN/DMPNN_hyper.json b/training/results/ClinToxDataset/DMPNN/DMPNN_hyper.json new file mode 100644 index 00000000..e98f0a07 --- /dev/null +++ b/training/results/ClinToxDataset/DMPNN/DMPNN_hyper.json @@ -0,0 +1 @@ +{"model": {"class_name": "make_model", "module_name": "kgcnn.literature.DMPNN", "config": {"name": "DMPNN", "inputs": [{"shape": [null, 41], "name": "node_attributes", "dtype": "float32"}, {"shape": [null, 11], "name": "edge_attributes", "dtype": "float32"}, {"shape": [null, 2], "name": "edge_indices", "dtype": "int64"}, {"shape": [null, 1], "name": "edge_indices_reverse", "dtype": "int64"}, {"shape": [], "name": "total_nodes", "dtype": "int64"}, {"shape": [], "name": "total_edges", "dtype": "int64"}], "cast_disjoint_kwargs": {}, "input_node_embedding": {"input_dim": 95, "output_dim": 64}, "input_edge_embedding": {"input_dim": 5, "output_dim": 64}, "input_graph_embedding": {"input_dim": 100, "output_dim": 64}, "pooling_args": {"pooling_method": "scatter_sum"}, "edge_initialize": {"units": 128, "use_bias": true, "activation": "relu"}, "edge_dense": {"units": 128, "use_bias": true, "activation": "linear"}, "edge_activation": {"activation": "relu"}, "node_dense": {"units": 128, "use_bias": true, "activation": "relu"}, "verbose": 10, "depth": 5, "dropout": {"rate": 0.1}, "output_embedding": "graph", "output_mlp": {"use_bias": [true, true, false], "units": [64, 32, 1], "activation": ["relu", "relu", "sigmoid"]}}}, "training": {"fit": {"batch_size": 32, "epochs": 50, "validation_freq": 1, "verbose": 2, "callbacks": []}, "compile": {"optimizer": {"class_name": "Adam", "config": {"learning_rate": {"module": "keras_core.optimizers.schedules", "class_name": "ExponentialDecay", "config": {"initial_learning_rate": 0.001, "decay_steps": 1600, "decay_rate": 0.5, "staircase": false}}}}, "loss": "binary_crossentropy", "metrics": ["binary_accuracy", {"class_name": "AUC", "config": {"name": "auc"}}]}}, "dataset": {"class_name": "ClinToxDataset", "module_name": "kgcnn.data.datasets.ClinToxDataset", "config": {}, "methods": [{"set_attributes": {}}, {"set_train_test_indices_k_fold": {"n_splits": 5, "random_state": 42, "shuffle": true}}, {"map_list": {"method": "set_edge_indices_reverse"}}, {"map_list": {"method": "count_nodes_and_edges"}}]}, "data": {"data_unit": "mol/L"}, "info": {"postfix": "", "postfix_file": "", "kgcnn_version": "4.0.0"}} \ No newline at end of file diff --git a/training/results/ClinToxDataset/GraphSAGE/GraphSAGE_ClinToxDataset_score.yaml b/training/results/ClinToxDataset/GraphSAGE/GraphSAGE_ClinToxDataset_score.yaml new file mode 100644 index 00000000..141602cd --- /dev/null +++ b/training/results/ClinToxDataset/GraphSAGE/GraphSAGE_ClinToxDataset_score.yaml @@ -0,0 +1,154 @@ +OS: nt_win32 +auc: +- 0.9994004964828491 +- 0.9992138147354126 +- 0.9997279644012451 +- 0.9992911219596863 +- 0.999356746673584 +backend: tensorflow +binary_accuracy: +- 0.9890202879905701 +- 0.9881856441497803 +- 0.9932432174682617 +- 0.9864750504493713 +- 0.9898648858070374 +cuda_available: 'False' +data_unit: '' +date_time: '2023-09-28 18:57:48' +device_id: '[LogicalDevice(name=''/device:CPU:0'', device_type=''CPU'')]' +device_memory: '[]' +device_name: '[{}]' +epochs: +- 100 +- 100 +- 100 +- 100 +- 100 +execute_folds: null +kgcnn_version: 4.0.0 +learning_rate: +- 0.0005000000237487257 +- 0.0005000000237487257 +- 0.0005000000237487257 +- 0.0005000000237487257 +- 0.0005000000237487257 +loss: +- 0.014190061949193478 +- 0.012561065144836903 +- 0.00922443624585867 +- 0.013985068537294865 +- 0.014338349923491478 +max_auc: +- 0.9994129538536072 +- 0.9993635416030884 +- 0.9997595548629761 +- 0.9993796348571777 +- 0.9994169473648071 +max_binary_accuracy: +- 0.9915540814399719 +- 0.9932489395141602 +- 0.9957770109176636 +- 0.9932375550270081 +- 0.9923986196517944 +max_learning_rate: +- 0.0005000000237487257 +- 0.0005000000237487257 +- 0.0005000000237487257 +- 0.0005000000237487257 +- 0.0005000000237487257 +max_loss: +- 0.4204205870628357 +- 0.27780216932296753 +- 0.4287554621696472 +- 0.27537333965301514 +- 0.2519896924495697 +max_val_auc: +- 0.9243290424346924 +- 0.8957247734069824 +- 0.9461678266525269 +- 0.8720597624778748 +- 0.8281375169754028 +max_val_binary_accuracy: +- 0.9493243098258972 +- 0.9525423645973206 +- 0.9493243098258972 +- 0.9696969985961914 +- 0.9527027010917664 +max_val_loss: +- 0.6296008229255676 +- 0.4612368643283844 +- 0.3711312711238861 +- 0.3351901173591614 +- 0.5245105028152466 +min_auc: +- 0.42854464054107666 +- 0.4985462427139282 +- 0.4686736762523651 +- 0.5199596881866455 +- 0.5527502298355103 +min_binary_accuracy: +- 0.8344594836235046 +- 0.9189873337745667 +- 0.8572635054588318 +- 0.9306846857070923 +- 0.9366554021835327 +min_learning_rate: +- 0.0005000000237487257 +- 0.0005000000237487257 +- 0.0005000000237487257 +- 0.0005000000237487257 +- 0.0005000000237487257 +min_loss: +- 0.013901567086577415 +- 0.012501571327447891 +- 0.00910576619207859 +- 0.013965529389679432 +- 0.01430636178702116 +min_val_auc: +- 0.49748921394348145 +- 0.5072992444038391 +- 0.4549601972103119 +- 0.54942786693573 +- 0.5217825174331665 +min_val_binary_accuracy: +- 0.9020270109176636 +- 0.8983050584793091 +- 0.8581081032752991 +- 0.8888888955116272 +- 0.9155405163764954 +min_val_loss: +- 0.18837836384773254 +- 0.24647729098796844 +- 0.14736203849315643 +- 0.11688490211963654 +- 0.17578420042991638 +model_class: make_model +model_name: GraphSAGE +model_version: '2023-09-18' +multi_target_indices: null +number_histories: 5 +seed: 42 +time_list: +- '0:01:05.354713' +- '0:01:05.513818' +- '0:01:09.061250' +- '0:01:12.538492' +- '0:01:11.984464' +val_auc: +- 0.7613853216171265 +- 0.8507994413375854 +- 0.8481254577636719 +- 0.7496821284294128 +- 0.7966626882553101 +val_binary_accuracy: +- 0.9358108043670654 +- 0.9389830231666565 +- 0.9459459185600281 +- 0.9528619647026062 +- 0.9324324131011963 +val_loss: +- 0.6261833906173706 +- 0.38423800468444824 +- 0.32406172156333923 +- 0.3191233277320862 +- 0.3854322135448456 diff --git a/training/results/ClinToxDataset/GraphSAGE/GraphSAGE_hyper.json b/training/results/ClinToxDataset/GraphSAGE/GraphSAGE_hyper.json new file mode 100644 index 00000000..b7a7c62c --- /dev/null +++ b/training/results/ClinToxDataset/GraphSAGE/GraphSAGE_hyper.json @@ -0,0 +1 @@ +{"model": {"class_name": "make_model", "module_name": "kgcnn.literature.GraphSAGE", "config": {"name": "GraphSAGE", "inputs": [{"shape": [null, 41], "name": "node_attributes", "dtype": "float32"}, {"shape": [null, 11], "name": "edge_attributes", "dtype": "float32"}, {"shape": [null, 2], "name": "edge_indices", "dtype": "int64"}, {"shape": [], "name": "total_nodes", "dtype": "int64"}, {"shape": [], "name": "total_edges", "dtype": "int64"}], "cast_disjoint_kwargs": {"padded_disjoint": false}, "input_node_embedding": {"input_dim": 95, "output_dim": 64}, "input_edge_embedding": {"input_dim": 25, "output_dim": 1}, "node_mlp_args": {"units": [64, 32], "use_bias": true, "activation": ["relu", "linear"]}, "edge_mlp_args": {"units": 64, "use_bias": true, "activation": "relu"}, "pooling_args": {"pooling_method": "scatter_mean"}, "gather_args": {}, "concat_args": {"axis": -1}, "use_edge_features": true, "pooling_nodes_args": {"pooling_method": "scatter_mean"}, "depth": 3, "verbose": 10, "output_embedding": "graph", "output_mlp": {"use_bias": [true, true, false], "units": [64, 32, 1], "activation": ["relu", "relu", "sigmoid"]}}}, "training": {"fit": {"batch_size": 32, "epochs": 100, "validation_freq": 1, "verbose": 2, "callbacks": [{"class_name": "kgcnn>LinearLearningRateScheduler", "config": {"learning_rate_start": 0.0005, "learning_rate_stop": 1e-05, "epo_min": 400, "epo": 500, "verbose": 0}}]}, "compile": {"optimizer": {"class_name": "Adam", "config": {"learning_rate": 0.005}}, "loss": "binary_crossentropy", "metrics": ["binary_accuracy", {"class_name": "AUC", "config": {"name": "auc"}}]}}, "dataset": {"class_name": "ClinToxDataset", "module_name": "kgcnn.data.datasets.ClinToxDataset", "config": {}, "methods": [{"set_attributes": {}}, {"set_train_test_indices_k_fold": {"n_splits": 5, "random_state": 42, "shuffle": true}}, {"map_list": {"method": "count_nodes_and_edges"}}]}, "data": {"data_unit": ""}, "info": {"postfix": "", "postfix_file": "", "kgcnn_version": "4.0.0"}} \ No newline at end of file diff --git a/training/results/ESOLDataset/GraphSAGE/GraphSAGE_ESOLDataset_score.yaml b/training/results/ESOLDataset/GraphSAGE/GraphSAGE_ESOLDataset_score.yaml new file mode 100644 index 00000000..e1b5c1e3 --- /dev/null +++ b/training/results/ESOLDataset/GraphSAGE/GraphSAGE_ESOLDataset_score.yaml @@ -0,0 +1,154 @@ +OS: nt_win32 +backend: tensorflow +cuda_available: 'False' +data_unit: mol/L +date_time: '2023-09-28 19:16:55' +device_id: '[LogicalDevice(name=''/device:CPU:0'', device_type=''CPU'')]' +device_memory: '[]' +device_name: '[{}]' +epochs: +- 500 +- 500 +- 500 +- 500 +- 500 +execute_folds: null +kgcnn_version: 4.0.0 +learning_rate: +- 1.4899999769113492e-05 +- 1.4899999769113492e-05 +- 1.4899999769113492e-05 +- 1.4899999769113492e-05 +- 1.4899999769113492e-05 +loss: +- 0.04259883612394333 +- 0.04020535200834274 +- 0.03714504465460777 +- 0.060478441417217255 +- 0.03989425301551819 +max_learning_rate: +- 0.0005000000237487257 +- 0.0005000000237487257 +- 0.0005000000237487257 +- 0.0005000000237487257 +- 0.0005000000237487257 +max_loss: +- 1.6915653944015503 +- 1.1960229873657227 +- 1.455209493637085 +- 1.155138373374939 +- 1.2369121313095093 +max_scaled_mean_absolute_error: +- 3.5749387741088867 +- 2.547924518585205 +- 3.1352100372314453 +- 2.420825719833374 +- 2.649657726287842 +max_scaled_root_mean_squared_error: +- 6.005917072296143 +- 3.527874708175659 +- 5.4530110359191895 +- 3.501129627227783 +- 4.414116382598877 +max_val_loss: +- 0.5244345664978027 +- 0.677221417427063 +- 0.5049918293952942 +- 0.6176096200942993 +- 0.5000714659690857 +max_val_scaled_mean_absolute_error: +- 1.0246177911758423 +- 1.5791653394699097 +- 1.0038889646530151 +- 1.300362229347229 +- 0.9890004396438599 +max_val_scaled_root_mean_squared_error: +- 1.3315452337265015 +- 1.9463471174240112 +- 1.3520362377166748 +- 1.6177583932876587 +- 1.2658278942108154 +min_learning_rate: +- 1.4899999769113492e-05 +- 1.4899999769113492e-05 +- 1.4899999769113492e-05 +- 1.4899999769113492e-05 +- 1.4899999769113492e-05 +min_loss: +- 0.04259883612394333 +- 0.03956563398241997 +- 0.03714504465460777 +- 0.0502941869199276 +- 0.037707339972257614 +min_scaled_mean_absolute_error: +- 0.09039484709501266 +- 0.08490517735481262 +- 0.08079701662063599 +- 0.1068895161151886 +- 0.08060325682163239 +min_scaled_root_mean_squared_error: +- 0.15949144959449768 +- 0.19087107479572296 +- 0.17207421362400055 +- 0.20473530888557434 +- 0.14576339721679688 +min_val_loss: +- 0.23218749463558197 +- 0.22514306008815765 +- 0.212930828332901 +- 0.2065148651599884 +- 0.20075100660324097 +min_val_scaled_mean_absolute_error: +- 0.4822452962398529 +- 0.4742461144924164 +- 0.5108270645141602 +- 0.48348742723464966 +- 0.4364219903945923 +min_val_scaled_root_mean_squared_error: +- 0.7145922780036926 +- 0.6578125357627869 +- 0.7568360567092896 +- 0.6597887873649597 +- 0.605344831943512 +model_class: make_model +model_name: GraphSAGE +model_version: '2023-09-18' +multi_target_indices: null +number_histories: 5 +scaled_mean_absolute_error: +- 0.09039484709501266 +- 0.08707469701766968 +- 0.08079701662063599 +- 0.12755875289440155 +- 0.085512176156044 +scaled_root_mean_squared_error: +- 0.15949144959449768 +- 0.19087107479572296 +- 0.17207421362400055 +- 0.22107534110546112 +- 0.14576339721679688 +seed: 42 +time_list: +- '0:02:35.692576' +- '0:02:34.830918' +- '0:02:53.413417' +- '0:03:09.658550' +- '0:03:26.535506' +val_loss: +- 0.24865631759166718 +- 0.2364959716796875 +- 0.21403230726718903 +- 0.2065148651599884 +- 0.21513915061950684 +val_scaled_mean_absolute_error: +- 0.5077494978904724 +- 0.485759437084198 +- 0.5120793581008911 +- 0.48348742723464966 +- 0.4479927122592926 +val_scaled_root_mean_squared_error: +- 0.7423257231712341 +- 0.6686792373657227 +- 0.7884992361068726 +- 0.6775819659233093 +- 0.6139715313911438 diff --git a/training/results/ESOLDataset/GraphSAGE/GraphSAGE_hyper.json b/training/results/ESOLDataset/GraphSAGE/GraphSAGE_hyper.json new file mode 100644 index 00000000..f74c82b3 --- /dev/null +++ b/training/results/ESOLDataset/GraphSAGE/GraphSAGE_hyper.json @@ -0,0 +1 @@ +{"model": {"class_name": "make_model", "module_name": "kgcnn.literature.GraphSAGE", "config": {"name": "GraphSAGE", "inputs": [{"shape": [null, 41], "name": "node_attributes", "dtype": "float32"}, {"shape": [null, 11], "name": "edge_attributes", "dtype": "float32"}, {"shape": [null, 2], "name": "edge_indices", "dtype": "int64"}, {"shape": [], "name": "total_nodes", "dtype": "int64"}, {"shape": [], "name": "total_edges", "dtype": "int64"}], "cast_disjoint_kwargs": {}, "input_node_embedding": {"input_dim": 95, "output_dim": 64}, "input_edge_embedding": {"input_dim": 32, "output_dim": 32}, "node_mlp_args": {"units": [64, 32], "use_bias": true, "activation": ["relu", "linear"]}, "edge_mlp_args": {"units": 64, "use_bias": true, "activation": "relu"}, "pooling_args": {"pooling_method": "scatter_mean"}, "gather_args": {}, "concat_args": {"axis": -1}, "use_edge_features": true, "pooling_nodes_args": {"pooling_method": "scatter_sum"}, "depth": 3, "verbose": 10, "output_embedding": "graph", "output_mlp": {"use_bias": [true, true, false], "units": [64, 32, 1], "activation": ["relu", "relu", "linear"]}}}, "training": {"fit": {"batch_size": 32, "epochs": 500, "validation_freq": 10, "verbose": 2, "callbacks": [{"class_name": "kgcnn>LinearLearningRateScheduler", "config": {"learning_rate_start": 0.0005, "learning_rate_stop": 1e-05, "epo_min": 400, "epo": 500, "verbose": 0}}]}, "compile": {"optimizer": {"class_name": "Adam", "config": {"learning_rate": 0.005}}, "loss": "mean_absolute_error"}, "cross_validation": {"class_name": "KFold", "config": {"n_splits": 5, "random_state": 42, "shuffle": true}}, "scaler": {"class_name": "StandardLabelScaler", "config": {"with_std": true, "with_mean": true, "copy": true}}}, "data": {"data_unit": "mol/L"}, "info": {"postfix": "", "postfix_file": "", "kgcnn_version": "4.0.0"}, "dataset": {"class_name": "ESOLDataset", "module_name": "kgcnn.data.datasets.ESOLDataset", "config": {}, "methods": [{"set_attributes": {}}, {"map_list": {"method": "count_nodes_and_edges"}}]}} \ No newline at end of file diff --git a/training/results/README.md b/training/results/README.md index c23481f5..326928b9 100644 --- a/training/results/README.md +++ b/training/results/README.md @@ -10,36 +10,54 @@ due to their file size. To show overall best test error run ``python3 summary.py --min_max True``. If not noted otherwise, we use a (fixed) random k-fold split for validation errors. +#### ClinToxDataset + +ClinTox (MoleculeNet) consists of 1478 compounds as smiles and data of drugs approved by the FDA and those that have failed clinical trials for toxicity reasons. We use random 5-fold cross-validation. The first label 'approved' is chosen as target. + +| model | kgcnn | epochs | Accuracy | AUC(ROC) | +|:----------|:--------|---------:|:-----------------------|:-----------------------| +| DMPNN | 4.0.0 | 50 | 0.9480 ± 0.0138 | 0.8297 ± 0.0568 | +| GAT | 4.0.0 | 50 | **0.9480 ± 0.0070** | 0.8512 ± 0.0468 | +| GATv2 | 4.0.0 | 50 | 0.9372 ± 0.0155 | **0.8587 ± 0.0754** | +| GCN | 4.0.0 | 50 | 0.9432 ± 0.0155 | 0.8555 ± 0.0593 | +| GIN | 4.0.0 | 50 | 0.9412 ± 0.0034 | 0.8066 ± 0.0636 | +| GraphSAGE | 4.0.0 | 100 | 0.9412 ± 0.0073 | 0.8013 ± 0.0422 | +| Schnet | 4.0.0 | 50 | 0.9277 ± 0.0102 | 0.6562 ± 0.0760 | + +#### CoraDataset + +Cora Dataset of 19793 publications and 8710 sparse node attributes and 70 node classes. Here we use random 5-fold cross-validation on nodes. + +| model | kgcnn | epochs | Categorical accuracy | +|:--------|:--------|---------:|:-----------------------| +| GAT | 4.0.0 | 250 | 0.6132 ± 0.0115 | +| GCN | 4.0.0 | 300 | **0.6232 ± 0.0054** | +| GIN | 4.0.0 | 800 | 0.5170 ± 0.2336 | + #### CoraLuDataset Cora Dataset after Lu et al. (2003) of 2708 publications and 1433 sparse attributes and 7 node classes. Here we use random 5-fold cross-validation on nodes. | model | kgcnn | epochs | Categorical accuracy | |:----------|:--------|---------:|:-----------------------| +| DMPNN | 4.0.0 | 300 | 0.8357 ± 0.0156 | | GAT | 4.0.0 | 250 | 0.8464 ± 0.0105 | | GATv2 | 4.0.0 | 250 | 0.8331 ± 0.0104 | | GCN | 4.0.0 | 300 | 0.8072 ± 0.0109 | | GIN | 4.0.0 | 500 | 0.8279 ± 0.0170 | | GraphSAGE | 4.0.0 | 500 | **0.8497 ± 0.0100** | -#### CoraDataset - -Cora Dataset of 19793 publications and 8710 sparse node attributes and 70 node classes. Here we use random 5-fold cross-validation on nodes. - -| model | kgcnn | epochs | Categorical accuracy | -|:--------|:--------|---------:|:-----------------------| -| GCN | 4.0.0 | 300 | **0.6232 ± 0.0054** | - #### ESOLDataset ESOL consists of 1128 compounds as smiles and their corresponding water solubility in log10(mol/L). We use random 5-fold cross-validation. -| model | kgcnn | epochs | MAE [log mol/L] | RMSE [log mol/L] | -|:--------|:--------|---------:|:-------------------|:-------------------| -| DMPNN | 4.0.0 | 300 | 0.4556 ± 0.0281 | 0.6471 ± 0.0299 | -| GAT | 4.0.0 | 500 | **nan ± nan** | **nan ± nan** | -| GCN | 4.0.0 | 800 | 0.4613 ± 0.0205 | 0.6534 ± 0.0513 | -| Schnet | 4.0.0 | 800 | nan ± nan | nan ± nan | +| model | kgcnn | epochs | MAE [log mol/L] | RMSE [log mol/L] | +|:----------|:--------|---------:|:-------------------|:-------------------| +| DMPNN | 4.0.0 | 300 | 0.4556 ± 0.0281 | 0.6471 ± 0.0299 | +| GAT | 4.0.0 | 500 | **nan ± nan** | **nan ± nan** | +| GCN | 4.0.0 | 800 | 0.4613 ± 0.0205 | 0.6534 ± 0.0513 | +| GraphSAGE | 4.0.0 | 500 | 0.4874 ± 0.0228 | 0.6982 ± 0.0608 | +| Schnet | 4.0.0 | 800 | nan ± nan | nan ± nan | #### MatProjectJdft2dDataset @@ -49,15 +67,6 @@ Materials Project dataset from Matbench with 636 crystal structures and their co |:--------------------------|:--------|---------:|:-------------------------|:--------------------------| | Schnet.make_crystal_model | 4.0.0 | 800 | **47.0970 ± 12.1636** | **121.0402 ± 38.7995** | -#### ClinToxDataset - -ClinTox (MoleculeNet) consists of 1478 compounds as smiles and data of drugs approved by the FDA and those that have failed clinical trials for toxicity reasons. We use random 5-fold cross-validation. The first label 'approved' is chosen as target. - -| model | kgcnn | epochs | Accuracy | AUC(ROC) | -|:--------|:--------|---------:|:-----------------------|:-----------------------| -| GCN | 4.0.0 | 200 | **0.6911 ± 0.1028** | 0.7910 ± 0.0593 | -| GIN | 4.0.0 | 50 | 0.3447 ± 0.1142 | **0.8066 ± 0.0636** | - #### MD17Dataset Energies and forces for molecular dynamics trajectories of eight organic molecules. All geometries in A, energy labels in kcal/mol and force labels in kcal/mol/A. We use preset train-test split. Training on 1000 geometries, test on 500/1000 geometries. Errors are MAE for forces. Results are for the CCSD and CCSD(T) data in MD17. diff --git a/training/results/summary.py b/training/results/summary.py index 31b470e4..c3c46924 100644 --- a/training/results/summary.py +++ b/training/results/summary.py @@ -12,9 +12,22 @@ show_min_max = args["min_max"] benchmark_datasets = { - "CoraLuDataset": { + "ClinToxDataset": { "general_info": [ - "Cora Dataset after Lu et al. (2003) of 2708 publications and 1433 sparse attributes and 7 node classes. ", + "ClinTox (MoleculeNet) consists of 1478 compounds as smiles and ", + "data of drugs approved by the FDA and those that have failed clinical trials for toxicity reasons. ", + "We use random 5-fold cross-validation. The first label 'approved' is chosen as target." + ], + "targets": [ + {"metric": "val_binary_accuracy", "name": "Accuracy", "find_best": "max"}, + {"metric": "val_auc", "name": "AUC(ROC)", "find_best": "max"}, + {"metric": "max_val_binary_accuracy", "name": "*Max. Accuracy*", "find_best": "max", "is_min_max": True}, + {"metric": "max_val_auc", "name": "*Max. AUC*", "find_best": "max", "is_min_max": True} + ] + }, + "CoraDataset": { + "general_info": [ + "Cora Dataset of 19793 publications and 8710 sparse node attributes and 70 node classes. ", "Here we use random 5-fold cross-validation on nodes. ", ], "targets": [ @@ -23,9 +36,9 @@ "is_min_max": True}, ] }, - "CoraDataset": { + "CoraLuDataset": { "general_info": [ - "Cora Dataset of 19793 publications and 8710 sparse node attributes and 70 node classes. ", + "Cora Dataset after Lu et al. (2003) of 2708 publications and 1433 sparse attributes and 7 node classes. ", "Here we use random 5-fold cross-validation on nodes. ", ], "targets": [ @@ -268,19 +281,6 @@ # {"metric": "max_val_AUC_no_nan", "name": "*Max. AUC*", "find_best": "max", "is_min_max": True} # ] # }, - "ClinToxDataset": { - "general_info": [ - "ClinTox (MoleculeNet) consists of 1478 compounds as smiles and ", - "data of drugs approved by the FDA and those that have failed clinical trials for toxicity reasons. ", - "We use random 5-fold cross-validation. The first label 'approved' is chosen as target." - ], - "targets": [ - {"metric": "val_binary_accuracy", "name": "Accuracy", "find_best": "max"}, - {"metric": "val_auc", "name": "AUC(ROC)", "find_best": "max"}, - {"metric": "max_val_binary_accuracy", "name": "*Max. Accuracy*", "find_best": "max", "is_min_max": True}, - {"metric": "max_val_auc", "name": "*Max. AUC*", "find_best": "max", "is_min_max": True} - ] - }, # "QM7Dataset": { # "general_info": [ # "QM7 dataset is a subset of GDB-13. ", diff --git a/training/train_graph.py b/training/train_graph.py index 575a52b8..6a737db6 100644 --- a/training/train_graph.py +++ b/training/train_graph.py @@ -21,8 +21,8 @@ # for training and model setup. parser = argparse.ArgumentParser(description='Train a GNN on a graph regression or classification task.') parser.add_argument("--hyper", required=False, help="Filepath to hyperparameter config file (.py or .json).", - default="hyper/hyper_clintox.py") -parser.add_argument("--category", required=False, help="Graph model to train.", default="GATv2") + default="hyper/hyper_esol.py") +parser.add_argument("--category", required=False, help="Graph model to train.", default="GraphSAGE") parser.add_argument("--model", required=False, help="Graph model to train.", default=None) parser.add_argument("--dataset", required=False, help="Name of the dataset.", default=None) parser.add_argument("--make", required=False, help="Name of the class for model.", default=None) diff --git a/training/train_node.py b/training/train_node.py index ff43dc1d..7b2dddaa 100644 --- a/training/train_node.py +++ b/training/train_node.py @@ -20,7 +20,7 @@ parser = argparse.ArgumentParser(description='Train a GNN on a Citation dataset.') parser.add_argument("--hyper", required=False, help="Filepath to hyperparameter config file (.py or .json).", default="hyper/hyper_cora.py") -parser.add_argument("--category", required=False, help="Graph model to train.", default="GAT") +parser.add_argument("--category", required=False, help="Graph model to train.", default="GIN") parser.add_argument("--model", required=False, help="Graph model to train.", default=None) parser.add_argument("--dataset", required=False, help="Name of the dataset.", default=None) parser.add_argument("--make", required=False, help="Name of the class for model.", default=None)