Skip to content

Commit

Permalink
update for keras 3.0
Browse files Browse the repository at this point in the history
  • Loading branch information
PatReis committed Nov 16, 2023
1 parent 86b3670 commit d28beb7
Showing 1 changed file with 26 additions and 4 deletions.
30 changes: 26 additions & 4 deletions kgcnn/models/casting.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,27 @@

def template_cast_output(model_outputs,
output_embedding, output_tensor_type, input_tensor_type, cast_disjoint_kwargs):
"""TODO"""
r"""The standard model output template returns a single tensor of either "graph", "node", or "edge"
embeddings specified by :obj:`output_embedding` within the model.
The return tensor type is determined by :obj:`output_tensor_type` . Options are:
graph:
Tensor: Graph labels of shape `(batch, F)` .
nodes:
Tensor: Node labels for the graph of either type:
- ragged (RaggedTensor): Single tensor of shape `(batch, None, F)` .
- padded (Tensor): Padded tensor of shape `(batch, N, F)` .
- disjoint (Tensor): Disjoint representation of shape `([N], F)` .
edges:
Tensor: Edge labels for the graph of either type:
- ragged (RaggedTensor): Single tensor of shape `(batch, None, F)` .
- padded (Tensor): Padded tensor of shape `(batch, M, F)`
- disjoint (Tensor): Disjoint representation of shape `([M], F)` .
"""

out, batch_id_node, batch_id_edge, node_id, edge_id, count_nodes, count_edges = model_outputs

Expand Down Expand Up @@ -68,6 +88,8 @@ def template_cast_list_input(model_inputs,
Note that in place of nodes or edges also more than one tensor can be provided, depending on the model, for example
:obj:`[nodes_1, nodes_2, edges_1, edges_2, edge_indices, ...]` .
However, for future models we intend to used named inputs rather than a list that is sensible to ordering.
Whether to use mask or length tensor for padded as well as further parameter of casting has to be set with
(dict) :obj:`cast_disjoint_kwargs` .
Padded or Masked Inputs:
list: :obj:`[nodes, edges, angles, edge_indices, angle_indices, graph_state, image_translation, lattice,
Expand All @@ -83,7 +105,7 @@ def template_cast_list_input(model_inputs,
- angle_indices (Tensor): Index list for angles of shape `(batch, K, 2)` referring to edges.
- graph_state (Tensor): Graph attributes of shape `(batch, F)` .
- image_translation (Tensor): Indices of the periodic image the sending node is located in.
Shape is `(batch, M, 3)` .
Shape is `(batch, M, 3)` .
- lattice (Tensor): Lattice matrix of the periodic structure of shape `(batch, 3, 3)` .
- node_mask (Tensor): Mask for padded nodes of shape `(batch, N)` .
- edge_mask (Tensor): Mask for padded edges of shape `(batch, M)` .
Expand All @@ -105,7 +127,7 @@ def template_cast_list_input(model_inputs,
- angle_indices (RaggedTensor): Index list for angles of shape `(batch, None, 2)` referring to edges.
- graph_state (Tensor): Graph attributes of shape `(batch, F)` .
- image_translation (RaggedTensor): Indices of the periodic image the sending node is located in.
Shape is `(batch, None, 3)` .
Shape is `(batch, None, 3)` .
- lattice (Tensor): Lattice matrix of the periodic structure of shape `(batch, 3, 3)` .
Disjoint Input:
Expand All @@ -120,7 +142,7 @@ def template_cast_list_input(model_inputs,
- angle_indices (Tensor): Index list for angles of shape `(2, [K])` referring to edges.
- graph_state (Tensor): Graph attributes of shape `(batch, F)` .
- image_translation (Tensor): Indices of the periodic image the sending node is located in.
Shape is `([M], 3)` .
Shape is `([M], 3)` .
- lattice (Tensor): Lattice matrix of the periodic structure of shape `(batch, 3, 3)` .
- graph_id_node (Tensor): ID tensor of graph assignment in disjoint graph of shape `([N], )` .
- graph_id_edge (Tensor): ID tensor of graph assignment in disjoint graph of shape `([M], )` .
Expand Down

0 comments on commit d28beb7

Please sign in to comment.