Skip to content

Commit

Permalink
update for keras 3.0
Browse files Browse the repository at this point in the history
  • Loading branch information
PatReis committed Nov 13, 2023
1 parent 6597501 commit f6589f4
Show file tree
Hide file tree
Showing 5 changed files with 427 additions and 16 deletions.
5 changes: 5 additions & 0 deletions kgcnn/data/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import pandas as pd
import os
from sklearn.model_selection import KFold
from kgcnn.io.loader import experimental_tf_disjoint_list_generator
# import typing as t
from typing import Union, List, Callable, Dict, Optional
# from collections.abc import MutableSequence
Expand Down Expand Up @@ -328,6 +329,10 @@ def rename_property_on_graphs(self, old_property_name: str, new_property_name: s
set = assign_property
get = obtain_property

def tf_disjoint_data_generator(self, inputs, outputs, **kwargs):
module_logger.info("Dataloader is experimental and not fully tested nor stable.")
return experimental_tf_disjoint_list_generator(self, inputs=inputs, outputs=outputs, **kwargs)


class MemoryGraphDataset(MemoryGraphList):
r"""Dataset class for lists of graph tensor dictionaries stored on file and fit into memory.
Expand Down
92 changes: 92 additions & 0 deletions kgcnn/io/loader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
import keras_core as ks
import numpy as np
import tensorflow as tf


def experimental_tf_disjoint_list_generator(graphs,
inputs,
outputs,
has_nodes=True,
has_edges=True,
has_graph_state=False,
batch_size=32,
shuffle=True):
def generator():
dataset_size = len(graphs)
data_index = np.arange(dataset_size)

if shuffle:
np.random.shuffle(data_index)

for batch_index in range(0, dataset_size, batch_size):
idx = data_index[batch_index:batch_index + batch_size]
graphs_batch = [graphs[i] for i in idx]

batch_id_node, batch_id_edge, node_id, edge_id, count_nodes, count_edges = [None for _ in range(6)]
out = []
inputs_pos = 0
for j in range(int(has_nodes)):
array_list = [x[inputs[inputs_pos]["name"]] for x in graphs_batch]
out.append(np.concatenate(array_list, axis=0))
inputs_pos += 1
if j == 0:
count_nodes = np.array([len(x) for x in array_list], dtype="int64")
batch_id_node = np.repeat(np.arange(len(array_list), dtype="int64"), repeats=count_nodes)
node_id = np.concatenate([np.arange(x, dtype="int64") for x in count_nodes], axis=0)

for j in range(int(has_edges)):
array_list = [x[inputs[inputs_pos]["name"]] for x in graphs_batch]
out.append(np.concatenate(array_list, axis=0, dtype=inputs[inputs_pos]["dtype"]))
inputs_pos += 1

for j in range(int(has_graph_state)):
array_list = [x[inputs[inputs_pos]["name"]] for x in graphs_batch]
out.append(np.array(array_list, dtype=inputs[inputs_pos]["dtype"]))
inputs_pos += 1

# Indices
array_list = [x[inputs[inputs_pos]["name"]] for x in graphs_batch]
count_edges = np.array([len(x) for x in array_list], dtype="int64")
batch_id_edge = np.repeat(np.arange(len(array_list), dtype="int64"), repeats=count_edges)
edge_id = np.concatenate([np.arange(x, dtype="int64") for x in count_edges], axis=0)
edge_indices_flatten = np.concatenate(array_list, axis=0)

node_splits = np.pad(np.cumsum(count_nodes), [[1, 0]])
offset_edge_indices = np.expand_dims(np.repeat(node_splits[:-1], count_edges), axis=-1)
disjoint_indices = edge_indices_flatten + offset_edge_indices
disjoint_indices = np.transpose(disjoint_indices)
out.append(disjoint_indices)

out = out + [batch_id_node, batch_id_edge, node_id, edge_id, count_nodes, count_edges]

if isinstance(outputs, list):
out_y = []
for k in range(len(outputs)):
array_list = [x[outputs[k]["name"]] for x in graphs_batch]
out_y.append(np.array(array_list, dtype=outputs[k]["dtype"]))
elif isinstance(outputs, dict):
out_y = np.array(
[x[outputs["name"]] for x in graphs_batch], dtype=outputs["dtype"])
else:
raise ValueError()

yield tuple(out), out_y

input_spec = tuple([tf.TensorSpec(shape=tuple([None] + list(x["shape"])), dtype=x["dtype"]) for x in inputs])

if isinstance(outputs, list):
output_spec = tuple([tf.TensorSpec(shape=tuple([None] + list(x["shape"])), dtype=x["dtype"]) for x in outputs])
elif isinstance(outputs, dict):
output_spec = tf.TensorSpec(shape=tuple([None] + list(outputs["shape"])), dtype=outputs["dtype"])
else:
raise ValueError()

data_loader = tf.data.Dataset.from_generator(
generator,
output_signature=(
input_spec,
output_spec
)
)

return data_loader
5 changes: 3 additions & 2 deletions kgcnn/literature/GIN/_make.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
{"shape": (), "name": "total_edges", "dtype": "int64"}
],
"input_tensor_type": "padded",
"cast_disjoint_kwargs": {"padded_disjoint": False},
"cast_disjoint_kwargs": {},
"input_embedding": None, # deprecated
"input_node_embedding": {"input_dim": 95, "output_dim": 64},
"gin_mlp": {"units": [64, 64], "use_bias": True, "activation": ["relu", "linear"],
Expand Down Expand Up @@ -245,7 +245,8 @@ def make_model_edge(inputs: list = None,
# Wrapping disjoint model.
out = model_disjoint_edge(
[n, ed, disjoint_indices, batch_id_node, count_nodes],
use_node_embedding=len(inputs[0]['shape']) < 2, use_edge_embedding=len(inputs[1]['shape']) < 2,
use_node_embedding="float" not in inputs[0]['dtype'],
use_edge_embedding="float" not in inputs[1]['dtype'],
input_node_embedding=input_node_embedding, input_edge_embedding=input_edge_embedding,
depth=depth, gin_args=gin_args, gin_mlp=gin_mlp, last_mlp=last_mlp,
dropout=dropout, output_embedding=output_embedding, output_mlp=output_mlp
Expand Down
25 changes: 18 additions & 7 deletions kgcnn/models/casting.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,28 @@ def template_cast_output(model_outputs,
out = CastDisjointToBatchedGraphState(**cast_disjoint_kwargs)(out)
elif output_embedding == 'node':
if output_tensor_type in ["padded", "masked"]:
if input_tensor_type in ["padded", "masked"]:
if "static_batched_node_output_shape" in cast_disjoint_kwargs:
out_node_shape = cast_disjoint_kwargs["static_batched_node_output_shape"]
else:
out_node_shape = None
out = CastDisjointToBatchedAttributes(static_output_shape=out_node_shape, **cast_disjoint_kwargs)(
[out, batch_id_node, node_id, count_nodes])
if "static_batched_node_output_shape" in cast_disjoint_kwargs:
out_node_shape = cast_disjoint_kwargs["static_batched_node_output_shape"]
else:
out_node_shape = None
out = CastDisjointToBatchedAttributes(static_output_shape=out_node_shape, **cast_disjoint_kwargs)(
[out, batch_id_node, node_id, count_nodes])
if output_tensor_type in ["ragged", "jagged"]:
out = CastDisjointToRaggedAttributes()([out, batch_id_node, node_id, count_nodes])
else:
out = CastDisjointToBatchedGraphState(**cast_disjoint_kwargs)(out)
elif output_embedding == 'edge':
if output_tensor_type in ["padded", "masked"]:
if "static_batched_edge_output_shape" in cast_disjoint_kwargs:
out_edge_shape = cast_disjoint_kwargs["static_batched_edge_output_shape"]
else:
out_edge_shape = None
out = CastDisjointToBatchedAttributes(static_output_shape=out_edge_shape, **cast_disjoint_kwargs)(
[out, batch_id_edge, edge_id, count_edges])
if output_tensor_type in ["ragged", "jagged"]:
out = CastDisjointToRaggedAttributes()([out, batch_id_edge, edge_id, count_edges])
else:
out = CastDisjointToBatchedGraphState(**cast_disjoint_kwargs)(out)
else:
raise NotImplementedError()

Expand Down
316 changes: 309 additions & 7 deletions notebooks/tutorial_model_loading_options.ipynb

Large diffs are not rendered by default.

0 comments on commit f6589f4

Please sign in to comment.