Skip to content

Commit

Permalink
Address Marc's feedback.
Browse files Browse the repository at this point in the history
  • Loading branch information
Miltos Allamanis committed Feb 22, 2021
1 parent b42df77 commit 8dca26f
Showing 1 changed file with 12 additions and 8 deletions.
20 changes: 12 additions & 8 deletions ptgnn/neuralmodels/gnn/graphneuralnetwork.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,12 +334,16 @@ def tensorize(
else:
tensorized_edge_features = []
for edge_type in self.__edge_idx_to_type:
edge_features = datapoint.edge_features.get(edge_type)
if edge_features is None:
edge_features_for_edge_type = datapoint.edge_features.get(edge_type)
if edge_features_for_edge_type is None:
# No edges of type `edge_type`
tensorized_edge_features.append([])
else:
tensorized_edge_features.append(
[self.__edge_embedding_model.tensorize(e) for e in edge_features]
[
self.__edge_embedding_model.tensorize(e)
for e in edge_features_for_edge_type
]
)

tensorized_data = TensorizedGraphData(
Expand Down Expand Up @@ -371,7 +375,7 @@ def initialize_minibatch(self) -> Dict[str, Any]:
"edge_feature_data": [
self.__edge_embedding_model.initialize_minibatch()
if self.__edge_embedding_model is not None
else []
else None
for _ in range(len(self.__edge_types))
],
"num_nodes_per_graph": [],
Expand All @@ -397,9 +401,9 @@ def extend_minibatch_with(
tensorized_edge_feature_data = partial_minibatch["edge_feature_data"]
nodes_in_mb_so_far = partial_minibatch["num_nodes_in_mb"]

all_edge_features = tensorized_datapoint.edge_features
if all_edge_features is None:
all_edge_features = [None for _ in range(len(adj_list))]
datapoint_edge_features = tensorized_datapoint.edge_features
if datapoint_edge_features is None:
datapoint_edge_features = [None for _ in range(len(adj_list))]

for (
sample_adj_list_for_edge_type,
Expand All @@ -408,7 +412,7 @@ def extend_minibatch_with(
mb_edge_feature_data,
) in zip(
tensorized_datapoint.adjacency_lists,
all_edge_features,
datapoint_edge_features,
adj_list,
tensorized_edge_feature_data,
):
Expand Down

0 comments on commit 8dca26f

Please sign in to comment.