-
Notifications
You must be signed in to change notification settings - Fork 30
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
4 changed files
with
155 additions
and
97 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
import tensorflow as tf | ||
|
||
|
||
def scatter_min(indices, values, shape): | ||
return tf.scatter_min(indices, values, shape) | ||
|
||
def scatter_max(indices, values, shape): | ||
return tf.scatter_nd(indices, values, shape) | ||
|
||
def scatter_mean(indices, values, shape): | ||
|
||
return tf.scatter_nd(indices, values, shape) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
from kgcnn import backend | ||
from keras_core.backend import KerasTensor | ||
from keras_core.backend import any_symbolic_tensors | ||
from keras_core.ops.operation import Operation | ||
|
||
|
||
class ScatterMax(Operation): | ||
def call(self, indices, values, shape): | ||
return backend.scatter_max(indices, values, shape) | ||
|
||
def compute_output_spec(self, indices, values, shape): | ||
return KerasTensor(shape, dtype=values.dtype) | ||
|
||
|
||
def scatter_max(indices, values, shape): | ||
if any_symbolic_tensors((indices, values, shape)): | ||
return ScatterMax().symbolic_call(indices, values, shape) | ||
return backend.scatter_max(indices, values, shape) | ||
|
||
|
||
class ScatterMin(Operation): | ||
def call(self, indices, values, shape): | ||
return backend.scatter_min(indices, values, shape) | ||
|
||
def compute_output_spec(self, indices, values, shape): | ||
return KerasTensor(shape, dtype=values.dtype) | ||
|
||
|
||
def scatter_min(indices, values, shape): | ||
if any_symbolic_tensors((indices, values, shape)): | ||
return ScatterMin().symbolic_call(indices, values, shape) | ||
return backend.scatter_min(indices, values, shape) | ||
|
||
|
||
class ScatterMean(Operation): | ||
def call(self, indices, values, shape): | ||
return backend.scatter_mean(indices, values, shape) | ||
|
||
def compute_output_spec(self, indices, values, shape): | ||
return KerasTensor(shape, dtype=values.dtype) | ||
|
||
|
||
def scatter_mean(indices, values, shape): | ||
if any_symbolic_tensors((indices, values, shape)): | ||
return ScatterMean().symbolic_call(indices, values, shape) | ||
return backend.scatter_mean(indices, values, shape) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1 @@ | ||
{"model": {"class_name": "make_model", "module_name": "kgcnn.literature_core.GCN", "config": {"name": "GCN", "inputs": [{"shape": [null, 41], "name": "node_attributes", "dtype": "float32"}, {"shape": [null, 1], "name": "edge_weights", "dtype": "float32"}, {"shape": [null, 2], "name": "edge_indices", "dtype": "int64"}, {"shape": [], "name": "graph_size", "dtype": "int64"}, {"shape": [], "name": "edge_count", "dtype": "int64"}], "input_node_embedding": {"input_dim": 95, "output_dim": 64}, "input_edge_embedding": {"input_dim": 25, "output_dim": 1}, "gcn_args": {"units": 140, "use_bias": true, "activation": "relu"}, "depth": 5, "verbose": 10, "output_embedding": "graph", "output_mlp": {"use_bias": [true, true, false], "units": [140, 70, 1], "activation": ["relu", "relu", "linear"]}}}, "training": {"fit": {"batch_size": 32, "epochs": 800, "validation_freq": 10, "verbose": 2, "callbacks": [{"class_name": "kgcnn>LinearLearningRateScheduler", "config": {"learning_rate_start": 0.001, "learning_rate_stop": 5e-05, "epo_min": 250, "epo": 800, "verbose": 0}}]}, "compile": {"optimizer": {"class_name": "Adam", "config": {"learning_rate": 0.001}}, "loss": "mean_absolute_error", "jit_compile": false}, "cross_validation": {"class_name": "KFold", "config": {"n_splits": 5, "random_state": 42, "shuffle": true}}, "scaler": {"class_name": "StandardScaler", "config": {"with_std": true, "with_mean": true, "copy": true}}}, "dataset": {"class_name": "ESOLDataset", "module_name": "kgcnn.data.datasets.ESOLDataset", "config": {}, "methods": [{"set_attributes": {}}, {"map_list": {"method": "normalize_edge_weights_sym"}}]}, "data": {"data_unit": "mol/L"}, "info": {"postfix": "", "postfix_file": "", "kgcnn_version": "2.0.3"}} | ||
{"model": {"class_name": "make_model", "module_name": "kgcnn.literature_core.GCN", "config": {"name": "GCN", "inputs": [{"shape": [null, 41], "name": "node_attributes", "dtype": "float32"}, {"shape": [null, 1], "name": "edge_weights", "dtype": "float32"}, {"shape": [null, 2], "name": "edge_indices", "dtype": "int64"}, {"shape": [], "name": "graph_size", "dtype": "int64"}, {"shape": [], "name": "edge_count", "dtype": "int64"}], "input_node_embedding": {"input_dim": 95, "output_dim": 64}, "input_edge_embedding": {"input_dim": 25, "output_dim": 1}, "gcn_args": {"units": 140, "use_bias": true, "activation": "relu"}, "depth": 5, "verbose": 10, "output_embedding": "graph", "output_mlp": {"use_bias": [true, true, false], "units": [140, 70, 1], "activation": ["relu", "relu", "linear"]}}}, "training": {"fit": {"batch_size": 32, "epochs": 800, "validation_freq": 10, "verbose": 2, "callbacks": [{"class_name": "kgcnn>LinearLearningRateScheduler", "config": {"learning_rate_start": 0.001, "learning_rate_stop": 5e-05, "epo_min": 250, "epo": 800, "verbose": 0}}]}, "compile": {"optimizer": {"class_name": "Adam", "config": {"learning_rate": 0.001}}, "loss": "mean_absolute_error", "jit_compile": true}, "cross_validation": {"class_name": "KFold", "config": {"n_splits": 5, "random_state": 42, "shuffle": true}}, "scaler": {"class_name": "StandardScaler", "config": {"with_std": true, "with_mean": true, "copy": true}}}, "dataset": {"class_name": "ESOLDataset", "module_name": "kgcnn.data.datasets.ESOLDataset", "config": {}, "methods": [{"set_attributes": {}}, {"map_list": {"method": "normalize_edge_weights_sym"}}]}, "data": {"data_unit": "mol/L"}, "info": {"postfix": "", "postfix_file": "", "kgcnn_version": "2.0.3"}} |