From 64b0aebeaffcd98be0e5f996fb38b30daa51777b Mon Sep 17 00:00:00 2001 From: PatReis Date: Tue, 2 Jan 2024 13:28:27 +0100 Subject: [PATCH] refactor general docs string for models --- kgcnn/literature/AttentiveFP/_make.py | 4 +- kgcnn/models/casting.py | 81 +++++++++++++++++++++++++++ 2 files changed, 83 insertions(+), 2 deletions(-) diff --git a/kgcnn/literature/AttentiveFP/_make.py b/kgcnn/literature/AttentiveFP/_make.py index 3118d999..322b9d62 100644 --- a/kgcnn/literature/AttentiveFP/_make.py +++ b/kgcnn/literature/AttentiveFP/_make.py @@ -1,7 +1,7 @@ import keras as ks from kgcnn.layers.scale import get as get_scaler from kgcnn.models.utils import update_model_kwargs -from kgcnn.models.casting import template_cast_output, template_cast_list_input +from kgcnn.models.casting import template_cast_output, template_cast_list_input, template_cast_list_input_docs from keras.backend import backend as backend_to_use from kgcnn.layers.modules import Input from ._model import model_disjoint @@ -162,4 +162,4 @@ def set_scale(*args, **kwargs): return model -make_model.__doc__ = make_model.__doc__ % (template_cast_list_input.__doc__, template_cast_output.__doc__) +make_model.__doc__ = make_model.__doc__ % (template_cast_list_input_docs, template_cast_output.__doc__) diff --git a/kgcnn/models/casting.py b/kgcnn/models/casting.py index ec20b571..581b637a 100644 --- a/kgcnn/models/casting.py +++ b/kgcnn/models/casting.py @@ -71,6 +71,87 @@ def template_cast_output(model_outputs, return out +template_cast_list_input_docs = r""" + Template of listed graph input tensors, which should be compatible to previous kgcnn versions and + defines the order as follows: :obj:`[nodes, edges, angles, edge_indices, angle_indices, graph_state, ...]` . + Where '...' denotes further mask or ID tensors, which is required for certain input types (see below). + Depending on the model, some inputs may not be used (see model description for information on supported inputs). + For example if the model does not support angles and no graph attribute input, the input becomes: + :obj:`[nodes, edges, edge_indices, ...]` . + In case of crystal graphs lattice and translation information has to be added. This will give a possible input of + :obj:`[nodes, edges, angles, edge_indices, angle_indices, graph_state, image_translation, lattice,...]` . + Note that in place of nodes or edges also more than one tensor can be provided, depending on the model, for example + :obj:`[nodes_1, nodes_2, edges_1, edges_2, edge_indices, ...]` . + + However, for future models we intend to used named inputs rather than a list that is sensible to ordering. + Whether to use mask or length tensor for padded as well as further parameter of casting has to be set with + (dict) :obj:`cast_disjoint_kwargs` . + + Padded or Masked Inputs: + list: :obj:`[nodes, edges, angles, edge_indices, angle_indices, graph_state, image_translation, lattice, + node_mask/node_count, edge_mask/edge_count, angle_mask/angle_count]` + + - nodes (Tensor): Node attributes of shape `(batch, N, F)` or `(batch, N)` + using an embedding layer. + - edges (Tensor): Edge attributes of shape `(batch, M, F)` or `(batch, M)` + using an embedding layer. + - angles (Tensor): Angle attributes of shape `(batch, M, F)` or `(batch, K)` + using an embedding layer. + - edge_indices (Tensor): Index list for edges of shape `(batch, M, 2)` referring to nodes. + - angle_indices (Tensor): Index list for angles of shape `(batch, K, 2)` referring to edges. + - graph_state (Tensor): Graph attributes of shape `(batch, F)` . + - image_translation (Tensor): Indices of the periodic image the sending node is located in. + Shape is `(batch, M, 3)` . + - lattice (Tensor): Lattice matrix of the periodic structure of shape `(batch, 3, 3)` . + - node_mask (Tensor): Mask for padded nodes of shape `(batch, N)` . + - edge_mask (Tensor): Mask for padded edges of shape `(batch, M)` . + - angle_mask (Tensor): Mask for padded angles of shape `(batch, K)` . + - node_count (Tensor): Total number of nodes if padding is used of shape `(batch, )` . + - edge_count (Tensor): Total number of edges if padding is used of shape `(batch, )` . + - angle_count (Tensor): Total number of angle if padding is used of shape `(batch, )` . + + Ragged or Jagged Inputs: + list: :obj:`[nodes, edges, angles, edge_indices, angle_indices, graph_state, image_translation, lattice]` + + - nodes (RaggedTensor): Node attributes of shape `(batch, None, F)` or `(batch, None)` + using an embedding layer. + - edges (RaggedTensor): Edge attributes of shape `(batch, None, F)` or `(batch, None)` + using an embedding layer. + - angles (RaggedTensor): Angle attributes of shape `(batch, None, F)` or `(batch, None)` + using an embedding layer. + - edge_indices (RaggedTensor): Index list for edges of shape `(batch, None, 2)` referring to nodes. + - angle_indices (RaggedTensor): Index list for angles of shape `(batch, None, 2)` referring to edges. + - graph_state (Tensor): Graph attributes of shape `(batch, F)` . + - image_translation (RaggedTensor): Indices of the periodic image the sending node is located in. + Shape is `(batch, None, 3)` . + - lattice (Tensor): Lattice matrix of the periodic structure of shape `(batch, 3, 3)` . + + Disjoint Input: + list: :obj:`[nodes, edges, angles, edge_indices, angle_indices, graph_state, image_translation, lattice, + graph_id_node, graph_id_edge, graph_id_angle, nodes_id, edges_id, angle_id, nodes_count, edges_count, + angles_count]` + + - nodes (Tensor): Node attributes of shape `([N], F)` or `([N], )` using an embedding layer. + - edges (Tensor): Edge attributes of shape `([M], F)` or `([M], )` using an embedding layer. + - angles (Tensor): Angle attributes of shape `([K], F)` or `([K], )` using an embedding layer. + - edge_indices (Tensor): Index list for edges of shape `(2, [M])` referring to nodes. + - angle_indices (Tensor): Index list for angles of shape `(2, [K])` referring to edges. + - graph_state (Tensor): Graph attributes of shape `(batch, F)` . + - image_translation (Tensor): Indices of the periodic image the sending node is located in. + Shape is `([M], 3)` . + - lattice (Tensor): Lattice matrix of the periodic structure of shape `(batch, 3, 3)` . + - graph_id_node (Tensor): ID tensor of batch assignment in disjoint graph of shape `([N], )` . + - graph_id_edge (Tensor): ID tensor of batch assignment in disjoint graph of shape `([M], )` . + - graph_id_angle (Tensor): ID tensor of batch assignment in disjoint graph of shape `([K], )` . + - nodes_id (Tensor): The ID-tensor to assign each node to its respective graph of shape `([N], )` . + - edges_id (Tensor): The ID-tensor to assign each edge to its respective graph of shape `([M], )` . + - angle_id (Tensor): The ID-tensor to assign each edge to its respective graph of shape `([K], )` . + - nodes_count (Tensor): Tensor of number of nodes for each graph of shape `(batch, )` . + - edges_count (Tensor): Tensor of number of edges for each graph of shape `(batch, )` . + - angles_count (Tensor): Tensor of number of angles for each graph of shape `(batch, )` . +""" + + def template_cast_list_input(model_inputs, input_tensor_type, cast_disjoint_kwargs,