Skip to content

Commit

Permalink
continue keras core integration
Browse files Browse the repository at this point in the history
  • Loading branch information
PatReis committed Sep 23, 2023
1 parent 0c09e95 commit 34d3df4
Show file tree
Hide file tree
Showing 5 changed files with 28 additions and 21 deletions.
10 changes: 0 additions & 10 deletions docs/source/kgcnn.backend.rst

This file was deleted.

1 change: 0 additions & 1 deletion docs/source/kgcnn.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ Subpackages
.. toctree::
:maxdepth: 4

kgcnn.backend
kgcnn.crystal
kgcnn.data
kgcnn.graph
Expand Down
24 changes: 23 additions & 1 deletion kgcnn/data/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -744,7 +744,17 @@ def relocate(self, data_directory: str = None, file_name: str = None, file_direc
return self

def set_multi_target_labels(self, graph_labels: str = "graph_labels", multi_target_indices: list = None,
data_unit: str = None):
data_unit: Union[str, list] = None):
"""Select multiple targets in labels.
Args:
graph_labels (str): Name of the property that holds multiple targets.
multi_target_indices (list): List of indices of targets to select.
data_unit (str, list): Optional list of data units for all labels in `graph_labels` .
Returns:
tuple: List of label names and label units for each target.
"""

labels = np.array(self.obtain_property(graph_labels))
label_names = self.label_names if hasattr(self, "label_names") else None
Expand All @@ -768,6 +778,18 @@ def set_multi_target_labels(self, graph_labels: str = "graph_labels", multi_targ

def set_train_test_indices_k_fold(self, n_splits: int = 5, shuffle: bool = False, random_state: int = None,
train: str = "train", test: str = "test"):
"""Helper function to set train/test indices for each graph in the list from a random k-fold cross-validation.
Args:
n_splits (int): Number of splits.
shuffle (bool): Whether to shuffle indices.
random_state (int): Random seed for split.
train (str): Property to assign train indices to.
test (str): Property to assign test indices to.
Returns:
None.
"""
kf = KFold(n_splits=n_splits, shuffle=shuffle, random_state=random_state)
for x in self:
x.set(train, [])
Expand Down
6 changes: 3 additions & 3 deletions kgcnn/layers/casting.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from keras_core import ops
from kgcnn.ops.core import repeat_static_length
from kgcnn.ops.scatter import scatter_reduce_sum
# from keras_core.backend import backend
from keras_core.backend import is_keras_tensor as is_tensor


def pad_left(t):
Expand Down Expand Up @@ -97,7 +97,7 @@ def call(self, inputs: list, **kwargs):
- nodes_count (Tensor): Tensor of number of nodes for each graph of shape `(batch, )` .
- edges_count (Tensor): Tensor of number of edges for each graph of shape `(batch, )` .
"""
all_tensor = all([ops.is_tensor(x) for x in inputs])
all_tensor = all([is_tensor(x) for x in inputs])

nodes, edge_indices, node_len, edge_len = inputs

Expand Down Expand Up @@ -247,7 +247,7 @@ def call(self, inputs: list, **kwargs):
- item_id (Tensor):
- item_counts (Tensor): Tensor of lengths for each graph of shape `(batch, )` .
"""
all_tensor = all([ops.is_tensor(x) for x in inputs])
all_tensor = all([is_tensor(x) for x in inputs])

# Case: Ragged Tensor input.
# As soon as ragged tensors are supported by Keras-Core. We will add this here to simply extract the disjoint
Expand Down
8 changes: 2 additions & 6 deletions kgcnn/layers/geom.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,9 @@
class NodePosition(Layer):
r"""Get node position for directed edges via node indices.
Directly calls :obj:`GatherNodesSelection` with provided index tensor.
Directly calls :obj:`GatherNodes` with provided index tensor.
Returns separate node position tensor for each of the indices. Index selection must be provided
in the constructor. Defaults to first two indices of an edge. This layer simply implements:
.. code-block:: python
GatherNodesSelection([0,1])([position, indices])
in the constructor. Defaults to first two indices of an edge.
A distance based edge is defined by two bond indices of the index list of shape `(batch, [M], 2)`
with last dimension of incoming and outgoing node (message passing framework).
Expand Down

0 comments on commit 34d3df4

Please sign in to comment.