Skip to content

Commit

Permalink
refactor general docs string for models
Browse files Browse the repository at this point in the history
  • Loading branch information
PatReis committed Jan 2, 2024
1 parent 0c20ad5 commit 64b0aeb
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 2 deletions.
4 changes: 2 additions & 2 deletions kgcnn/literature/AttentiveFP/_make.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import keras as ks
from kgcnn.layers.scale import get as get_scaler
from kgcnn.models.utils import update_model_kwargs
from kgcnn.models.casting import template_cast_output, template_cast_list_input
from kgcnn.models.casting import template_cast_output, template_cast_list_input, template_cast_list_input_docs
from keras.backend import backend as backend_to_use
from kgcnn.layers.modules import Input
from ._model import model_disjoint
Expand Down Expand Up @@ -162,4 +162,4 @@ def set_scale(*args, **kwargs):
return model


make_model.__doc__ = make_model.__doc__ % (template_cast_list_input.__doc__, template_cast_output.__doc__)
make_model.__doc__ = make_model.__doc__ % (template_cast_list_input_docs, template_cast_output.__doc__)
81 changes: 81 additions & 0 deletions kgcnn/models/casting.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,87 @@ def template_cast_output(model_outputs,
return out


template_cast_list_input_docs = r"""
Template of listed graph input tensors, which should be compatible to previous kgcnn versions and
defines the order as follows: :obj:`[nodes, edges, angles, edge_indices, angle_indices, graph_state, ...]` .
Where '...' denotes further mask or ID tensors, which is required for certain input types (see below).
Depending on the model, some inputs may not be used (see model description for information on supported inputs).
For example if the model does not support angles and no graph attribute input, the input becomes:
:obj:`[nodes, edges, edge_indices, ...]` .
In case of crystal graphs lattice and translation information has to be added. This will give a possible input of
:obj:`[nodes, edges, angles, edge_indices, angle_indices, graph_state, image_translation, lattice,...]` .
Note that in place of nodes or edges also more than one tensor can be provided, depending on the model, for example
:obj:`[nodes_1, nodes_2, edges_1, edges_2, edge_indices, ...]` .
However, for future models we intend to used named inputs rather than a list that is sensible to ordering.
Whether to use mask or length tensor for padded as well as further parameter of casting has to be set with
(dict) :obj:`cast_disjoint_kwargs` .
Padded or Masked Inputs:
list: :obj:`[nodes, edges, angles, edge_indices, angle_indices, graph_state, image_translation, lattice,
node_mask/node_count, edge_mask/edge_count, angle_mask/angle_count]`
- nodes (Tensor): Node attributes of shape `(batch, N, F)` or `(batch, N)`
using an embedding layer.
- edges (Tensor): Edge attributes of shape `(batch, M, F)` or `(batch, M)`
using an embedding layer.
- angles (Tensor): Angle attributes of shape `(batch, M, F)` or `(batch, K)`
using an embedding layer.
- edge_indices (Tensor): Index list for edges of shape `(batch, M, 2)` referring to nodes.
- angle_indices (Tensor): Index list for angles of shape `(batch, K, 2)` referring to edges.
- graph_state (Tensor): Graph attributes of shape `(batch, F)` .
- image_translation (Tensor): Indices of the periodic image the sending node is located in.
Shape is `(batch, M, 3)` .
- lattice (Tensor): Lattice matrix of the periodic structure of shape `(batch, 3, 3)` .
- node_mask (Tensor): Mask for padded nodes of shape `(batch, N)` .
- edge_mask (Tensor): Mask for padded edges of shape `(batch, M)` .
- angle_mask (Tensor): Mask for padded angles of shape `(batch, K)` .
- node_count (Tensor): Total number of nodes if padding is used of shape `(batch, )` .
- edge_count (Tensor): Total number of edges if padding is used of shape `(batch, )` .
- angle_count (Tensor): Total number of angle if padding is used of shape `(batch, )` .
Ragged or Jagged Inputs:
list: :obj:`[nodes, edges, angles, edge_indices, angle_indices, graph_state, image_translation, lattice]`
- nodes (RaggedTensor): Node attributes of shape `(batch, None, F)` or `(batch, None)`
using an embedding layer.
- edges (RaggedTensor): Edge attributes of shape `(batch, None, F)` or `(batch, None)`
using an embedding layer.
- angles (RaggedTensor): Angle attributes of shape `(batch, None, F)` or `(batch, None)`
using an embedding layer.
- edge_indices (RaggedTensor): Index list for edges of shape `(batch, None, 2)` referring to nodes.
- angle_indices (RaggedTensor): Index list for angles of shape `(batch, None, 2)` referring to edges.
- graph_state (Tensor): Graph attributes of shape `(batch, F)` .
- image_translation (RaggedTensor): Indices of the periodic image the sending node is located in.
Shape is `(batch, None, 3)` .
- lattice (Tensor): Lattice matrix of the periodic structure of shape `(batch, 3, 3)` .
Disjoint Input:
list: :obj:`[nodes, edges, angles, edge_indices, angle_indices, graph_state, image_translation, lattice,
graph_id_node, graph_id_edge, graph_id_angle, nodes_id, edges_id, angle_id, nodes_count, edges_count,
angles_count]`
- nodes (Tensor): Node attributes of shape `([N], F)` or `([N], )` using an embedding layer.
- edges (Tensor): Edge attributes of shape `([M], F)` or `([M], )` using an embedding layer.
- angles (Tensor): Angle attributes of shape `([K], F)` or `([K], )` using an embedding layer.
- edge_indices (Tensor): Index list for edges of shape `(2, [M])` referring to nodes.
- angle_indices (Tensor): Index list for angles of shape `(2, [K])` referring to edges.
- graph_state (Tensor): Graph attributes of shape `(batch, F)` .
- image_translation (Tensor): Indices of the periodic image the sending node is located in.
Shape is `([M], 3)` .
- lattice (Tensor): Lattice matrix of the periodic structure of shape `(batch, 3, 3)` .
- graph_id_node (Tensor): ID tensor of batch assignment in disjoint graph of shape `([N], )` .
- graph_id_edge (Tensor): ID tensor of batch assignment in disjoint graph of shape `([M], )` .
- graph_id_angle (Tensor): ID tensor of batch assignment in disjoint graph of shape `([K], )` .
- nodes_id (Tensor): The ID-tensor to assign each node to its respective graph of shape `([N], )` .
- edges_id (Tensor): The ID-tensor to assign each edge to its respective graph of shape `([M], )` .
- angle_id (Tensor): The ID-tensor to assign each edge to its respective graph of shape `([K], )` .
- nodes_count (Tensor): Tensor of number of nodes for each graph of shape `(batch, )` .
- edges_count (Tensor): Tensor of number of edges for each graph of shape `(batch, )` .
- angles_count (Tensor): Tensor of number of angles for each graph of shape `(batch, )` .
"""


def template_cast_list_input(model_inputs,
input_tensor_type,
cast_disjoint_kwargs,
Expand Down

0 comments on commit 64b0aeb

Please sign in to comment.