Skip to content

Commit

Permalink
continue keras core integration
Browse files Browse the repository at this point in the history
  • Loading branch information
PatReis committed Sep 5, 2023
1 parent 153692a commit 057bce0
Show file tree
Hide file tree
Showing 7 changed files with 77 additions and 23 deletions.
2 changes: 1 addition & 1 deletion SDP.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,4 @@ Software Development Plan (SDP)
* [ ] Make a ``tf_dataset()`` function to return a generator dataset from `Graphlist` .
* [ ] Add ``JARVISDataset`` . There is already a (yet not fully) port for `kgcnn` .
* [ ] Add package wide Logger Level to change.
* [ ] Training scripts need all seed for maximum reproducibility.
* [x] Training scripts need all seed for maximum reproducibility.
8 changes: 4 additions & 4 deletions kgcnn/backend/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@

# Import backend functions.
if backend() == "tensorflow":
from kgcnn.backend.tensorflow import (scatter_max, scatter_min, scatter_mean)
from kgcnn.backend.tensorflow import (scatter_sum, scatter_max, scatter_min, scatter_mean)
elif backend() == "jax":
from kgcnn.backend.tensorflow import (scatter_max, scatter_min, scatter_mean)
from kgcnn.backend.tensorflow import (scatter_sum, scatter_max, scatter_min, scatter_mean)
elif backend() == "torch":
from kgcnn.backend.tensorflow import (scatter_max, scatter_min, scatter_mean)
from kgcnn.backend.tensorflow import (scatter_sum, scatter_max, scatter_min, scatter_mean)
elif backend() == "numpy":
from kgcnn.backend.tensorflow import (scatter_max, scatter_min, scatter_mean)
from kgcnn.backend.tensorflow import (scatter_sum, scatter_max, scatter_min, scatter_mean)
else:
raise ValueError(f"Unable to import backend : {backend()}")
4 changes: 4 additions & 0 deletions kgcnn/backend/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,7 @@ def scatter_max(indices, values, shape):
def scatter_mean(indices, values, shape):
counts = tf.scatter_nd(indices, tf.ones_like(values), shape)
return tf.scatter_nd(indices, values, shape)/counts


def scatter_sum(indices, values, shape):
return tf.scatter_nd(indices, values, shape)
26 changes: 21 additions & 5 deletions kgcnn/layers_core/aggr.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,31 @@
import keras_core as ks
import keras_core.saving
from keras_core.layers import Layer
from keras_core import ops
from kgcnn.ops_core.scatter import scatter_min, scatter_mean, scatter_max, scatter_sum


# @keras_core.saving.register_keras_serializable()
@ks.saving.register_keras_serializable(package='kgcnn', name='Aggregate')
class Aggregate(Layer):

def __init__(self, pooling_method: str = "sum", axis=0, **kwargs):
def __init__(self, pooling_method: str = "scatter_sum", axis=0, **kwargs):
super(Aggregate, self).__init__(**kwargs)
self.pooling_method = pooling_method
self.axis = axis
if axis != 0:
raise NotImplementedError
pooling_by_name = {
"scatter_sum": scatter_sum,
"scatter_mean": scatter_mean,
"scatter_max": scatter_max,
"scatter_min": scatter_min,
"segment_sum": None,
"segment_mean": None,
"segment_max": None,
"segment_min": None
}
self._pool_method = pooling_by_name[pooling_method]
self._use_scatter = "scatter" in pooling_method

def build(self, input_shape):
super(Aggregate, self).build(input_shape)
Expand All @@ -22,10 +37,11 @@ def compute_output_shape(self, input_shape):

def call(self, inputs, **kwargs):
x, index, dim_size = inputs
# For test only sum scatter, no segment operation no other poolings etc.
# will add all poolings here.
shape = ops.concatenate([dim_size, ops.shape(x)[1:]])
return ops.scatter(ops.expand_dims(index, axis=-1), x, shape=shape)
if self._use_scatter:
return self._pool_method(ops.expand_dims(index, axis=-1), x, shape=shape)
else:
raise NotImplementedError


class AggregateLocalEdges(Layer):
Expand Down
28 changes: 21 additions & 7 deletions kgcnn/ops_core/scatter.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
from kgcnn import backend
import kgcnn.backend as kgcnn_backend
from keras_core.backend import KerasTensor
from keras_core.backend import any_symbolic_tensors
from keras_core.ops.operation import Operation


class ScatterMax(Operation):
def call(self, indices, values, shape):
return backend.scatter_max(indices, values, shape)
return kgcnn_backend.scatter_max(indices, values, shape)

def compute_output_spec(self, indices, values, shape):
return KerasTensor(shape, dtype=values.dtype)
Expand All @@ -15,12 +15,12 @@ def compute_output_spec(self, indices, values, shape):
def scatter_max(indices, values, shape):
if any_symbolic_tensors((indices, values, shape)):
return ScatterMax().symbolic_call(indices, values, shape)
return backend.scatter_max(indices, values, shape)
return kgcnn_backend.scatter_max(indices, values, shape)


class ScatterMin(Operation):
def call(self, indices, values, shape):
return backend.scatter_min(indices, values, shape)
return kgcnn_backend.scatter_min(indices, values, shape)

def compute_output_spec(self, indices, values, shape):
return KerasTensor(shape, dtype=values.dtype)
Expand All @@ -29,12 +29,12 @@ def compute_output_spec(self, indices, values, shape):
def scatter_min(indices, values, shape):
if any_symbolic_tensors((indices, values, shape)):
return ScatterMin().symbolic_call(indices, values, shape)
return backend.scatter_min(indices, values, shape)
return kgcnn_backend.scatter_min(indices, values, shape)


class ScatterMean(Operation):
def call(self, indices, values, shape):
return backend.scatter_mean(indices, values, shape)
return kgcnn_backend.scatter_mean(indices, values, shape)

def compute_output_spec(self, indices, values, shape):
return KerasTensor(shape, dtype=values.dtype)
Expand All @@ -43,4 +43,18 @@ def compute_output_spec(self, indices, values, shape):
def scatter_mean(indices, values, shape):
if any_symbolic_tensors((indices, values, shape)):
return ScatterMean().symbolic_call(indices, values, shape)
return backend.scatter_mean(indices, values, shape)
return kgcnn_backend.scatter_mean(indices, values, shape)


class ScatterSum(Operation):
def call(self, indices, values, shape):
return kgcnn_backend.scatter_sum(indices, values, shape)

def compute_output_spec(self, indices, values, shape):
return KerasTensor(shape, dtype=values.dtype)


def scatter_sum(indices, values, shape):
if any_symbolic_tensors((indices, values, shape)):
return ScatterSum().symbolic_call(indices, values, shape)
return kgcnn_backend.scatter_sum(indices, values, shape)
12 changes: 6 additions & 6 deletions training_core/results/ESOLDataset/GCN/GCN_ESOLDataset_score.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
data_unit: mol/L
date_time: '2023-09-05 13:02:37'
date_time: '2023-09-05 15:13:24'
epochs:
- 800
- 800
Expand Down Expand Up @@ -123,11 +123,11 @@ scaled_root_mean_squared_error:
- 0.163068026304245
seed: 42
time_list:
- '0:06:16.883915'
- '0:06:34.529023'
- '0:07:15.199847'
- '0:10:23.172025'
- '0:10:06.016519'
- '0:05:59.835995'
- '0:06:20.888448'
- '0:06:34.970782'
- '0:10:05.953646'
- '0:11:27.211699'
val_loss:
- 0.23328134417533875
- 0.21415188908576965
Expand Down
20 changes: 20 additions & 0 deletions training_core/results/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# Summary of Benchmark Training

Note that these are the results for models within `kgcnn` implementation, and that training is not always done with optimal hyperparameter or splits, when comparing with literature.
This table is generated automatically from keras history logs.
Model weights and training statistics plots are not uploaded on
[github](https://github.com/aimat-lab/gcnn_keras/tree/master/training/results)
due to their file size.

*Max.* or *Min.* denotes the best test error observed for any epoch during training.
To show overall best test error run ``python3 summary.py --min_max True``.
If not noted otherwise, we use a (fixed) random k-fold split for validation errors.

#### ESOLDataset

ESOL consists of 1128 compounds as smiles and their corresponding water solubility in log10(mol/L). We use random 5-fold cross-validation.

| model | kgcnn | epochs | MAE [log mol/L] | RMSE [log mol/L] |
|:--------|:--------|---------:|:-----------------------|:-----------------------|
| GCN | 3.1.0 | 800 | **0.4741 ± 0.0188** | **0.6946 ± 0.0351** |

0 comments on commit 057bce0

Please sign in to comment.