Commit: update keras core.
PatReis committed Sep 11, 2023
1 parent 29e5dcc commit 9d7829c
Showing 10 changed files with 80 additions and 52 deletions.
4 changes: 4 additions & 0 deletions kgcnn/backend/tensorflow.py
@@ -2,20 +2,24 @@


def scatter_reduce_sum(indices, values, shape):
+   indices = tf.expand_dims(indices, axis=-1)
    return tf.scatter_nd(indices, values, tf.cast(shape, dtype="int64"))


def scatter_reduce_min(indices, values, shape):
+   indices = tf.expand_dims(indices, axis=-1)
    target = tf.fill(shape, tf.constant(values.dtype.max, dtype=values.dtype))
    return tf.tensor_scatter_nd_min(target, indices, values)


def scatter_reduce_max(indices, values, shape):
+   indices = tf.expand_dims(indices, axis=-1)
    target = tf.fill(shape, tf.constant(values.dtype.min, dtype=values.dtype))
    return tf.tensor_scatter_nd_max(target, indices, values)


def scatter_reduce_mean(indices, values, shape):
+   indices = tf.expand_dims(indices, axis=-1)
    counts = tf.scatter_nd(indices, tf.ones_like(values), shape)
    return tf.scatter_nd(indices, values, shape) / counts

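For reference, the backend scatter helpers can be exercised on their own; a minimal sketch in eager TensorFlow (the sample tensors are illustrative, not from the repository):

import tensorflow as tf

def scatter_reduce_sum(indices, values, shape):
    indices = tf.expand_dims(indices, axis=-1)
    return tf.scatter_nd(indices, values, tf.cast(shape, dtype="int64"))

# Three edge rows scattered onto two nodes; rows 0 and 2 both land on node 0 and are summed.
indices = tf.constant([0, 1, 0], dtype="int64")
values = tf.constant([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
print(scatter_reduce_sum(indices, values, tf.constant([2, 2])).numpy())
# [[6. 8.]
#  [3. 4.]]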
5 changes: 5 additions & 0 deletions kgcnn/backend/torch.py
@@ -2,25 +2,30 @@


def scatter_reduce_sum(indices, values, shape):
+   indices = torch.unsqueeze(indices, dim=-1)
    return torch.zeros(*shape, dtype=values.dtype).scatter_reduce(
        0, torch.broadcast_to(indices, values.shape), values, reduce='sum')


def scatter_reduce_min(indices, values, shape):
+   indices = torch.unsqueeze(indices, dim=-1)
    return torch.zeros(*shape, dtype=values.dtype).scatter_reduce(
        0, torch.broadcast_to(indices, values.shape), values, reduce='amin', include_self=False)


def scatter_reduce_max(indices, values, shape):
+   indices = torch.unsqueeze(indices, dim=-1)
    return torch.zeros(*shape, dtype=values.dtype).scatter_reduce(
        0, torch.broadcast_to(indices, values.shape), values, reduce='amax', include_self=False)


def scatter_reduce_mean(indices, values, shape):
+   indices = torch.unsqueeze(indices, dim=-1)
    return torch.zeros(*shape, dtype=values.dtype).scatter_reduce(
        0, torch.broadcast_to(indices, values.shape), values, reduce='mean', include_self=False)


def scatter_reduce_softmax(indices, values, shape):
+   indices = torch.unsqueeze(indices, dim=-1)
    return torch.zeros(*shape, dtype=values.dtype).scatter_reduce(
        0, torch.broadcast_to(indices, values.shape), values, reduce='sum')
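Note that scatter_reduce_softmax above still returns only the scatter sum. If a full segment softmax is the eventual intent, it could be composed from the same primitives; a sketch under that assumption (scatter_softmax_sketch is a hypothetical name, not the committed API):

import torch

def scatter_softmax_sketch(indices, values, shape):
    index = torch.broadcast_to(torch.unsqueeze(indices, dim=-1), values.shape)
    # Subtract the per-segment maximum for numerical stability.
    seg_max = torch.zeros(*shape, dtype=values.dtype).scatter_reduce(
        0, index, values, reduce='amax', include_self=False)
    exp = torch.exp(values - seg_max[indices])
    seg_sum = torch.zeros(*shape, dtype=values.dtype).scatter_reduce(0, index, exp, reduce='sum')
    return exp / seg_sum[indices]

print(scatter_softmax_sketch(torch.tensor([0, 0, 1]), torch.tensor([[0.0], [0.0], [1.0]]), (2, 1)))
# tensor([[0.5000], [0.5000], [1.0000]])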
4 changes: 2 additions & 2 deletions kgcnn/layers_core/aggr.py
@@ -30,7 +30,7 @@ def __init__(self, pooling_method: str = "scatter_sum", axis=0, **kwargs):

    def build(self, input_shape):
        # Nothing to build here. No sub-layers.
-       super(Aggregate, self).build(input_shape)
+       self.built = True

    def compute_output_shape(self, input_shape):
        assert len(input_shape) == 3
@@ -41,7 +41,7 @@ def call(self, inputs, **kwargs):
        x, index, reference = inputs
        shape = ops.shape(reference)[:1] + ops.shape(x)[1:]
        if self._use_scatter:
-           return self._pool_method(ops.expand_dims(index, axis=-1), x, shape=shape)
+           return self._pool_method(index, x, shape=shape)
        else:
            raise NotImplementedError()

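With the index expansion moved into the backend functions above, Aggregate now passes the flat index straight through. The same pooling step, traced by hand with the torch backend (illustrative tensors):

import torch

x = torch.tensor([[1.0], [2.0], [3.0]])    # per-edge features, shape (3, 1)
index = torch.tensor([0, 0, 1])            # receiving node of each edge
reference = torch.zeros(2, 1)              # node tensor that fixes the output length
shape = reference.shape[:1] + x.shape[1:]  # (2, 1)
out = torch.zeros(*shape, dtype=x.dtype).scatter_reduce(
    0, torch.broadcast_to(torch.unsqueeze(index, dim=-1), x.shape), x, reduce='sum')
print(out)  # tensor([[3.], [3.]]) -- edges 0 and 1 pool onto node 0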
41 changes: 21 additions & 20 deletions kgcnn/layers_core/casting.py
@@ -1,5 +1,6 @@
from keras_core.layers import Layer
from keras_core import ops
+from keras_core.backend import backend


class CastBatchedGraphIndicesToPyGDisjoint(Layer):
@@ -16,10 +17,16 @@ def __init__(self, reverse_indices: bool = True, dtype_batch: str = "int64", dty
        self.reverse_indices = reverse_indices
        self.dtype_index = dtype_index
        self.dtype_batch = dtype_batch
+       # self.supports_jit = supports_jit

    def build(self, input_shape):
        super(CastBatchedGraphIndicesToPyGDisjoint, self).build(input_shape)

+   def compute_output_shape(self, input_shape):
+       return [tuple([None] + list(input_shape[0][2:])),
+               tuple(list(reversed(input_shape[1][2:])) + [None]),
+               (None,), (None,), (None,)]
+
    def call(self, inputs: list, **kwargs):
        """Changes node and edge indices into a Pytorch Geometric (PyG) compatible tensor format.
@@ -28,10 +35,10 @@ def call(self, inputs: list, **kwargs):
            - nodes (Tensor): Node features are represented by a keras tensor of shape `(batch, N, F, ...)` ,
              where N denotes the number of nodes.
-           - edge_indices (Tensor): Edge index list has shape `(batch, M, 2)` with the indices of directed
+           - edge_indices (Tensor): Edge index list has shape `(batch, M, 2)` with the indices of M directed
              edges at last axis for each edge corresponding to `edges` .
-           - nodes_in_batch (Tensor, optional):
-           - edges_in_batch (Tensor, optional):
+           - nodes_in_batch (Tensor):
+           - edges_in_batch (Tensor):
@@ -44,8 +51,7 @@
        """
        all_tensor = all([ops.is_tensor(x) for x in inputs])

-       # Case 1: Padded node and edges tensors but with batch dimension at axis 0.
-       if all_tensor and len(inputs) == 4:
+       if all_tensor:
            nodes, edge_indices, node_len, edge_len = inputs
            node_len = ops.cast(node_len, dtype=self.dtype_batch)
            edge_len = ops.cast(edge_len, dtype=self.dtype_batch)
            # ...
            edge_indices_flatten = edge_indices[edge_mask]
            nodes_flatten = nodes[node_mask]

-       # Case 2: Fixed sized graphs without batch information.
-       elif all_tensor and len(inputs) == 2:
-           nodes, edges, edge_indices = inputs
-           nodes_flatten = ops.reshape(nodes, [-1] + list(ops.shape(nodes)[2:]))
-           edge_indices_flatten = ops.reshape(edge_indices, [-1] + list(ops.shape(edge_indices)[2:]))
-           node_len = ops.repeat(ops.cast([ops.shape(nodes)[1]], dtype=self.dtype_batch), ops.shape(nodes)[0])
-           edge_len = ops.repeat(ops.cast([ops.shape(edge_indices)[1]], dtype=self.dtype_batch),
-                                 ops.shape(edge_indices)[0])
+           # nodes_flatten = ops.reshape(nodes, [-1] + list(ops.shape(nodes)[2:]))
+           # edge_indices_flatten = ops.reshape(edge_indices, [-1] + list(ops.shape(edge_indices)[2:]))
+           # node_len = ops.repeat(ops.cast([ops.shape(nodes)[1]], dtype=self.dtype_batch), ops.shape(nodes)[0])
+           # edge_len = ops.repeat(ops.cast([ops.shape(edge_indices)[1]], dtype=self.dtype_batch),
+           #                       ops.shape(edge_indices)[0])

        # Case: Ragged Tensor input.
        # As soon as ragged tensors are supported by Keras-Core.
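The padded-to-disjoint cast above boils down to a boolean length mask; a small numpy sketch of the same computation (the keras_core ops mirror these numpy calls):

import numpy as np

nodes = np.array([[[0., 0.], [0., 1.]],
                  [[1., 0.], [1., 1.]]])   # padded (batch=2, N=2, F=2)
node_len = np.array([1, 2])                # valid nodes per graph
node_mask = np.arange(nodes.shape[1])[None, :] < node_len[:, None]
print(nodes[node_mask])                    # [[0. 0.] [1. 0.] [1. 1.]] -> disjoint (3, 2)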
@@ -103,10 +106,14 @@ def __init__(self, reverse_indices: bool = True, dtype_batch: str = "int64", **k
        super(CastBatchedGraphAttributesToPyGDisjoint, self).__init__(**kwargs)
        self.reverse_indices = reverse_indices
        self.dtype_batch = dtype_batch
+       self.supports_jit = False

    def build(self, input_shape):
        super(CastBatchedGraphAttributesToPyGDisjoint, self).build(input_shape)

+   def compute_output_shape(self, input_shape):
+       return [tuple([None] + list(input_shape[0][2:])), (None,)]
+
    def call(self, inputs: list, **kwargs):
        """Changes node or edge tensors into a Pytorch Geometric (PyG) compatible tensor format.
@@ -124,19 +131,13 @@ def call(self, inputs: list, **kwargs):
            - counts_in_batch (Tensor): Tensor of lengths for each graph of shape `(batch, )` .
        """
        # Case 1: Padded node and edges tensors but with batch dimension at axis 0.
-       if isinstance(inputs, list):
+       if all([ops.is_tensor(x) for x in inputs]):
            nodes, node_len = inputs
            node_len = ops.cast(node_len, dtype=self.dtype_batch)
            node_mask = ops.repeat(ops.expand_dims(ops.arange(ops.shape(nodes)[1], dtype=self.dtype_batch), axis=0),
                                   ops.shape(node_len)[0], axis=0) < ops.expand_dims(node_len, axis=-1)
            nodes_flatten = nodes[node_mask]

-       # Case 2: Fixed sized graphs without batch information.
-       elif ops.is_tensor(inputs):
-           nodes = inputs
-           nodes_flatten = ops.reshape(nodes, [-1] + list(ops.shape(nodes)[2:]))
-           node_len = ops.repeat(ops.cast([ops.shape(nodes)[1]], dtype=self.dtype_batch), ops.shape(nodes)[0])
-
        # Case: Ragged Tensor input.
        # As soon as ragged tensors are supported by Keras-Core.

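For orientation, the new compute_output_shape methods map batched shapes to their disjoint counterparts; a worked example with illustrative sizes:

# CastBatchedGraphIndicesToPyGDisjoint, stepped through by hand:
input_shape = [(32, 20, 64), (32, 40, 2), (32,), (32,)]  # nodes, edge_indices, node_len, edge_len
# tuple([None] + list(input_shape[0][2:]))            -> (None, 64)  flattened node features
# tuple(list(reversed(input_shape[1][2:])) + [None])  -> (2, None)   edge indices, index axis first
# plus (None,), (None,), (None,)                      -> batch assignment and both length vectors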
15 changes: 9 additions & 6 deletions kgcnn/literature_core/GCN/_model.py
@@ -7,6 +7,9 @@
from kgcnn.layers_core.pooling import PoolingNodes
from kgcnn.model.utils import update_model_kwargs
from keras_core.backend import backend as backend_to_use
+from keras_core.layers import Activation
+from kgcnn.layers_core.aggr import AggregateWeightedLocalEdges
+from kgcnn.layers_core.gather import GatherNodesOutgoing

# Keep track of model version from commit date in literature.
__kgcnn_model_version__ = "2023.09.30"
@@ -97,7 +100,7 @@ def make_model(inputs: list = None,
    # Make input
    model_inputs = [ks.layers.Input(**x) for x in inputs]
    batched_nodes, batched_edges, batched_indices, count_nodes, count_edges = model_inputs
-   n, disjoint_indices, batch, _, _ = CastBatchedGraphIndicesToPyGDisjoint(**cast_indices_kwargs)([
+   n, disjoint_indices, _, _, _ = CastBatchedGraphIndicesToPyGDisjoint(**cast_indices_kwargs)([
        batched_nodes, batched_indices, count_nodes, count_edges])
    e, _ = CastBatchedGraphAttributesToPyGDisjoint()([batched_edges, count_edges])

@@ -111,13 +114,13 @@
    n = Dense(gcn_args["units"], use_bias=True, activation='linear')(n)  # Map to units

    for i in range(0, depth):
-       n = GCN(**gcn_args)([n, e, disjoint_indices])
+       # n = GCN(**gcn_args)([n, e, disjoint_indices])

        # # Equivalent as:
-       # no = Dense(gcn_args["units"], activation="linear")(n)
-       # no = GatherNodesOutgoing()([no, disjoint_indices])
-       # nu = AggregateWeightedLocalEdges()([n, no, disjoint_indices, e])
-       # n = Activation(gcn_args["activation"])(nu)
+       no = Dense(gcn_args["units"], activation="linear")(n)
+       no = GatherNodesOutgoing()([no, disjoint_indices])
+       nu = AggregateWeightedLocalEdges()([n, no, disjoint_indices, e])
+       n = Activation(gcn_args["activation"])(nu)

    # Output embedding choice
    if output_embedding == "graph":
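The unrolled block corresponds to one weighted GCN convolution, roughly h_i' = act(sum_{j->i} e_ij * W h_j); an annotated reading of the four layers (shapes are illustrative, N nodes and M edges, with the aggregation semantics inferred from the layer names):

# no = Dense(units)(n)                                # (N, F) -> (N, units): W h_j
# no = GatherNodesOutgoing()([no, disjoint_indices])  # (M, units): sender features per edge
# nu = AggregateWeightedLocalEdges()([n, no, disjoint_indices, e])  # (N, units): e_ij-weighted scatter sum
# n = Activation(gcn_args["activation"])(nu)          # h_i' = act(...)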
4 changes: 2 additions & 2 deletions kgcnn/training_core/history.py
@@ -8,7 +8,7 @@
import keras_core as ks
import keras_core.callbacks
from keras_core.backend import backend
-from kgcnn.utils_core.devices import check_device_cuda
+from kgcnn.utils_core.devices import check_device


def load_history_list(file_path, folds):
@@ -114,7 +114,7 @@ def save_history_score(
    result_dict["seed"] = seed
    result_dict["backend"] = backend()
    result_dict["OS"] = "%s_%s" % (os.name, sys.platform)
-   result_dict.update(check_device_cuda())
+   result_dict.update(check_device())
    if trajectory_name:
        result_dict["trajectory_name"] = trajectory_name

9 changes: 8 additions & 1 deletion kgcnn/utils_core/devices.py
@@ -6,7 +6,7 @@
module_logger.setLevel(logging.INFO)


-def check_device_cuda():
+def check_device():

    if backend() == "tensorflow":
        import tensorflow as tf
@@ -33,6 +33,13 @@ def check_device_cuda():
        logical_device_list = [x for x in range(torch.cuda.device_count())]
        memory_info = [{"allocated": round(torch.cuda.memory_allocated(i)/1024**3, 1),
                        "cached": round(torch.cuda.memory_reserved(i)/1024**3, 1)} for i in logical_device_list]
+   elif backend() == "jax":
+       import jax
+       jax_devices = jax.devices()
+       cuda_is_available = any([x.platform == "gpu" for x in jax_devices])
+       physical_device_name = [x for x in jax_devices]
+       logical_device_list = [x.id for x in jax_devices]
+       memory_info = [x.memory_stats() for x in jax_devices]

    else:
        raise NotImplementedError("Backend '%s' is not supported for `check_device`." % backend())
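For reference, jax exposes device metadata like so; platform is the field that reads "gpu"/"cpu"/"tpu", while device_kind is a free-form hardware name (output depends on the machine):

import jax

for d in jax.devices():
    print(d.id, d.platform, d.device_kind)
# e.g. "0 gpu NVIDIA A100-SXM4-40GB" on a GPU host, "0 cpu cpu" on CPU-only hosts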
42 changes: 25 additions & 17 deletions test_core/test_layers_casting.py
@@ -2,7 +2,7 @@

from keras_core import ops
from keras_core import testing
-from kgcnn.layers_core.casting import CastBatchedGraphIndicesToPyGDisjoint
+from kgcnn.layers_core.casting import CastBatchedGraphIndicesToPyGDisjoint, CastBatchedGraphAttributesToPyGDisjoint


class CastBatchedGraphsToPyGDisjointTest(testing.TestCase):
@@ -17,33 +17,41 @@ class CastBatchedGraphsToPyGDisjointTest(testing.TestCase):
    node_len = np.array([1, 2], dtype="int64")
    edge_len = np.array([1, 3], dtype="int64")

-   def test_correctness_lengths(self):
+   def test_correctness(self):

        layer = CastBatchedGraphIndicesToPyGDisjoint()
        node_attr, edge_index, batch, _, _ = layer(
-           [self.nodes, self.edges,
-            ops.cast(self.edge_indices, dtype="int64"),
+           [self.nodes, ops.cast(self.edge_indices, dtype="int64"),
             self.node_len, self.edge_len
             ])
        self.assertAllClose(node_attr, [[0.0, 0.0], [1.0, 0.0], [1.0, 1.0]])
-       # self.assertAllClose(edge_attr, [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [1.0, 0.0, 1.0], [1.0, 1.0, 0.0]])
        self.assertAllClose(edge_index, [[0, 1, 2, 1], [0, 1, 1, 2]])
        self.assertAllClose(batch, [0, 1, 1])

-   def test_correctness_equal_size(self):
-
-       layer = CastBatchedGraphIndicesToPyGDisjoint(reverse_indices=False)
-       node_attr, edge_index, batch, _, _ = layer(
-           [self.nodes, self.edges, ops.cast(self.edge_indices, dtype="int64")]
-       )
-       self.assertAllClose(node_attr, [[0.0, 0.0], [0.0, 1.0], [1.0, 0.0], [1.0, 1.0]])
-       # self.assertAllClose(edge_attr, [[0.0, 0.0, 0.0], [0.0, 0.0, 1.0], [0.0, 1.0, 0.0], [1.0, 1.0, 1.0],
-       #                                 [1.0, 0.0, 0.0], [1.0, 0.0, 1.0], [1.0, 1.0, 0.0], [-1.0, 1.0, 1.0]])
-       self.assertAllClose(edge_index, [[0, 0, 1, 1, 2, 2, 3, 3], [0, 1, 0, 1, 2, 3, 2, 3]])
-       self.assertAllClose(batch, [0, 0, 1, 1])
+
+
+class TestCastBatchedGraphAttributesToPyGDisjoint(testing.TestCase):
+
+   nodes = np.array([[[0.0, 0.0], [0.0, 1.0]], [[1.0, 0.0], [1.0, 1.0]]])
+   edges = np.array([[[0.0, 0.0, 0.0], [0.0, 0.0, 1.0], [0.0, 1.0, 0.0], [1.0, 1.0, 1.0]],
+                     [[1.0, 0.0, 0.0], [1.0, 0.0, 1.0], [1.0, 1.0, 0.0], [-1.0, 1.0, 1.0]]])
+   edge_indices = np.array([[[0, 0], [0, 1], [1, 0], [1, 1]],
+                            [[0, 0], [0, 1], [1, 0], [1, 1]]], dtype="int64")
+   node_mask = np.array([[True, False], [True, True]])
+   edge_mask = np.array([[True, False, False, False], [True, True, True, False]])
+   node_len = np.array([1, 2], dtype="int64")
+   edge_len = np.array([1, 3], dtype="int64")
+
+   def test_correctness(self):
+
+       layer = CastBatchedGraphAttributesToPyGDisjoint()
+       node_attr, _ = layer(
+           [self.nodes, self.node_len])
+       self.assertAllClose(node_attr, [[0.0, 0.0], [1.0, 0.0], [1.0, 1.0]])


if __name__ == "__main__":
-   CastBatchedGraphsToPyGDisjointTest().test_correctness_lengths()
-   CastBatchedGraphsToPyGDisjointTest().test_correctness_equal_size()
-   print("Tests passed.")
+   CastBatchedGraphsToPyGDisjointTest().test_correctness()
+   TestCastBatchedGraphAttributesToPyGDisjoint().test_correctness()
+   print("Tests passed.")
2 changes: 1 addition & 1 deletion training_core/hyper/hyper_esol.py
@@ -35,7 +35,7 @@
        },
        "compile": {
            "optimizer": {"class_name": "Adam", "config": {"learning_rate": 1e-03}},
-           "loss": "mean_absolute_error", "jit_compile": True
+           "loss": "mean_absolute_error"
        },
        "cross_validation": {"class_name": "KFold",
                             "config": {"n_splits": 5, "random_state": 42, "shuffle": True}},
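If XLA compilation is wanted back later, it would re-enter through this compile section (a sketch, assuming keras_core's Model.compile accepts a jit_compile flag; the cast layers above set supports_jit to False, which is presumably why the flag was dropped here):

"compile": {
    "optimizer": {"class_name": "Adam", "config": {"learning_rate": 1e-03}},
    "loss": "mean_absolute_error",
    # "jit_compile": True,  # only once every layer in the model supports jit
},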
6 changes: 3 additions & 3 deletions training_core/train_graph.py
@@ -13,7 +13,7 @@
from kgcnn.model.serial import deserialize as deserialize_model
from kgcnn.data.serial import deserialize as deserialize_dataset
from kgcnn.training_core.hyper import HyperParameter
-from kgcnn.utils_core.devices import check_device_cuda
+from kgcnn.utils_core.devices import check_device

# Input arguments from command line with default values from example.
# From command line, one can specify the model, dataset and the hyperparameter which contain all configuration
@@ -32,7 +32,7 @@
print("Input of argparse:", args)

# Check for gpu
-check_device_cuda()
+check_device()

# Set seed.
np.random.seed(args["seed"])
@@ -112,7 +112,7 @@
# The metrics from this script are added to the hyperparameter entry for metrics.
model.compile(**hyper.compile(metrics=metrics))
model.summary()
-
+print(model._jit_compile)
# Run keras model-fit and take time for training.
start = time.time()
hist = model.fit(x_train, y_train,
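The print above pokes a private attribute; if the jit state should be logged, a more defensive sketch (assuming some keras_core versions expose a public jit_compile property, which is not verified here):

print(getattr(model, "jit_compile", getattr(model, "_jit_compile", None)))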
