From d28beb7262b3e972e10117e4baf7f972388d3f4b Mon Sep 17 00:00:00 2001 From: PatReis Date: Thu, 16 Nov 2023 12:20:51 +0100 Subject: [PATCH] update for keras 3.0 --- kgcnn/models/casting.py | 30 ++++++++++++++++++++++++++---- 1 file changed, 26 insertions(+), 4 deletions(-) diff --git a/kgcnn/models/casting.py b/kgcnn/models/casting.py index c119e6be..90bdc9ab 100644 --- a/kgcnn/models/casting.py +++ b/kgcnn/models/casting.py @@ -9,7 +9,27 @@ def template_cast_output(model_outputs, output_embedding, output_tensor_type, input_tensor_type, cast_disjoint_kwargs): - """TODO""" + r"""The standard model output template returns a single tensor of either "graph", "node", or "edge" + embeddings specified by :obj:`output_embedding` within the model. + The return tensor type is determined by :obj:`output_tensor_type` . Options are: + + graph: + Tensor: Graph labels of shape `(batch, F)` . + + nodes: + Tensor: Node labels for the graph of either type: + + - ragged (RaggedTensor): Single tensor of shape `(batch, None, F)` . + - padded (Tensor): Padded tensor of shape `(batch, N, F)` . + - disjoint (Tensor): Disjoint representation of shape `([N], F)` . + + edges: + Tensor: Edge labels for the graph of either type: + + - ragged (RaggedTensor): Single tensor of shape `(batch, None, F)` . + - padded (Tensor): Padded tensor of shape `(batch, M, F)` + - disjoint (Tensor): Disjoint representation of shape `([M], F)` . + """ out, batch_id_node, batch_id_edge, node_id, edge_id, count_nodes, count_edges = model_outputs @@ -68,6 +88,8 @@ def template_cast_list_input(model_inputs, Note that in place of nodes or edges also more than one tensor can be provided, depending on the model, for example :obj:`[nodes_1, nodes_2, edges_1, edges_2, edge_indices, ...]` . However, for future models we intend to used named inputs rather than a list that is sensible to ordering. + Whether to use mask or length tensor for padded as well as further parameter of casting has to be set with + (dict) :obj:`cast_disjoint_kwargs` . Padded or Masked Inputs: list: :obj:`[nodes, edges, angles, edge_indices, angle_indices, graph_state, image_translation, lattice, @@ -83,7 +105,7 @@ def template_cast_list_input(model_inputs, - angle_indices (Tensor): Index list for angles of shape `(batch, K, 2)` referring to edges. - graph_state (Tensor): Graph attributes of shape `(batch, F)` . - image_translation (Tensor): Indices of the periodic image the sending node is located in. - Shape is `(batch, M, 3)` . + Shape is `(batch, M, 3)` . - lattice (Tensor): Lattice matrix of the periodic structure of shape `(batch, 3, 3)` . - node_mask (Tensor): Mask for padded nodes of shape `(batch, N)` . - edge_mask (Tensor): Mask for padded edges of shape `(batch, M)` . @@ -105,7 +127,7 @@ def template_cast_list_input(model_inputs, - angle_indices (RaggedTensor): Index list for angles of shape `(batch, None, 2)` referring to edges. - graph_state (Tensor): Graph attributes of shape `(batch, F)` . - image_translation (RaggedTensor): Indices of the periodic image the sending node is located in. - Shape is `(batch, None, 3)` . + Shape is `(batch, None, 3)` . - lattice (Tensor): Lattice matrix of the periodic structure of shape `(batch, 3, 3)` . Disjoint Input: @@ -120,7 +142,7 @@ def template_cast_list_input(model_inputs, - angle_indices (Tensor): Index list for angles of shape `(2, [K])` referring to edges. - graph_state (Tensor): Graph attributes of shape `(batch, F)` . - image_translation (Tensor): Indices of the periodic image the sending node is located in. - Shape is `([M], 3)` . + Shape is `([M], 3)` . - lattice (Tensor): Lattice matrix of the periodic structure of shape `(batch, 3, 3)` . - graph_id_node (Tensor): ID tensor of graph assignment in disjoint graph of shape `([N], )` . - graph_id_edge (Tensor): ID tensor of graph assignment in disjoint graph of shape `([M], )` .