diff --git a/ptgnn/baseneuralmodel/abstractneuralmodel.py b/ptgnn/baseneuralmodel/abstractneuralmodel.py index ac4994a..38f3dc4 100644 --- a/ptgnn/baseneuralmodel/abstractneuralmodel.py +++ b/ptgnn/baseneuralmodel/abstractneuralmodel.py @@ -152,7 +152,7 @@ def build_neural_module(self) -> TNeuralModule: # region Saving/Loading def save(self, path: Path, model: TNeuralModule) -> None: - os.makedirs(os.path.dirname(str(path)), exist_ok=True) + os.makedirs(os.path.dirname(str(path.absolute())), exist_ok=True) with gzip.open(path, "wb") as f: torch.save((self, model), f) diff --git a/ptgnn/implementations/typilus/graph2class.py b/ptgnn/implementations/typilus/graph2class.py index 596b137..70305d3 100644 --- a/ptgnn/implementations/typilus/graph2class.py +++ b/ptgnn/implementations/typilus/graph2class.py @@ -107,7 +107,7 @@ class Graph2Class( ): def __init__( self, - gnn_model: GraphNeuralNetworkModel[str, Any], + gnn_model: GraphNeuralNetworkModel, max_num_classes: int = 100, try_simplify_unks: bool = True, ): @@ -118,7 +118,7 @@ def __init__( self.__tensorize_samples_with_no_annotation = False self.__tensorize_keep_original_supernode_idx = False - def __convert(self, typilus_graph: TypilusGraph) -> Tuple[GraphData[str], List[str]]: + def __convert(self, typilus_graph: TypilusGraph) -> Tuple[GraphData[str, None], List[str]]: def get_adj_list(adjacency_dict): for from_node_idx, to_node_idxs in adjacency_dict.items(): from_node_idx = int(from_node_idx) @@ -149,7 +149,7 @@ def get_adj_list(adjacency_dict): supernode_annotations.append(enforce_not_None(supernode_data["annotation"])) return ( - GraphData[str]( + GraphData[str, None]( node_information=typilus_graph["nodes"], edges=edges, reference_nodes={ diff --git a/ptgnn/implementations/typilus/train.py b/ptgnn/implementations/typilus/train.py index ac5cdb5..c67898f 100644 --- a/ptgnn/implementations/typilus/train.py +++ b/ptgnn/implementations/typilus/train.py @@ -112,6 +112,7 @@ def create_mlp_mp_layers(num_edges: int): introduce_backwards_edges=True, add_self_edges=True, stop_extending_minibatch_after_num_nodes=120000, + edge_dropout_rate=0.0, ), max_num_classes=100, ) @@ -149,9 +150,12 @@ def run(arguments): initialize_metadata = True restore_path = arguments.get("--restore-path", None) - if restore_path or (arguments["--aml"] and model_path.exists()): + if restore_path: initialize_metadata = False model, nn = Graph2Class.restore_model(Path(restore_path)) + elif arguments["--aml"] and model_path.exists(): + initialize_metadata = False + model, nn = Graph2Class.restore_model(model_path) else: nn = None model = create_graph2class_gnn_model() diff --git a/ptgnn/neuralmodels/gnn/graphneuralnetwork.py b/ptgnn/neuralmodels/gnn/graphneuralnetwork.py index f5e5644..0b63244 100644 --- a/ptgnn/neuralmodels/gnn/graphneuralnetwork.py +++ b/ptgnn/neuralmodels/gnn/graphneuralnetwork.py @@ -36,6 +36,8 @@ def __init__( node_embedder: nn.Module, introduce_backwards_edges: bool, add_self_edges: bool, + edge_dropout_rate: float = 0.0, + edge_feature_embedder: Optional[nn.Module] = None, ): """ :param message_passing_layers: A list of message passing layers. @@ -43,12 +45,16 @@ def __init__( :param introduce_backwards_edges: If `True` special backwards edges should be automatically created. :param add_self_edges: If `True` self-edges will be added. These edges connect the same node across multiple timesteps. + :param edge_dropout_rate: remove random pct of edges """ super().__init__() self.__message_passing_layers = nn.ModuleList(message_passing_layers) self.__node_embedder = node_embedder self.__introduce_backwards_edges = introduce_backwards_edges self.__add_self_edges = add_self_edges + assert 0 <= edge_dropout_rate < 1 + self.__edge_dropout_rate = edge_dropout_rate + self.__edge_feature_embedder = edge_feature_embedder @property def input_node_state_dim(self) -> int: @@ -78,6 +84,7 @@ def gnn( self, node_representations: torch.Tensor, adjacency_lists: List[Tuple[torch.Tensor, torch.Tensor]], + edge_feature_embeddings: List[torch.Tensor], node_to_graph_idx: torch.Tensor, reference_node_ids: Dict[str, torch.Tensor], reference_node_graph_idx: Dict[str, torch.Tensor], @@ -88,12 +95,29 @@ def gnn( :param adjacency_lists: a list of [num_edges_per_type, 2] adjacency lists per edge type. The order is fixed across runs. Backwards edges and self-edges are included if the appropriate hyperparameter is set. + :param edge_feature_embeddings: a list of the edge features per edge-type. :param node_to_graph_idx: A mapping that tells us which graph the node belongs to :param reference_node_ids: A dictionary indicating the reference node index :param reference_node_graph_idx: A dictionary indicating the graph index for reference node :param return_all_states: Whether to return all states :return: a [num_nodes, output_hidden_dimension] matrix of the output representations """ + if self.__edge_dropout_rate > 0 and self.training: + dropped_adj_list, dropped_edge_features = [], [] + for (edge_sources_idxs, edge_target_idxs), edge_features in zip( + adjacency_lists, edge_feature_embeddings + ): + mask = ( + torch.rand_like(edge_sources_idxs, dtype=torch.float32) + > self.__edge_dropout_rate + ) + dropped_adj_list.append( + (edge_sources_idxs.masked_select(mask), edge_target_idxs.masked_select(mask)) + ) + dropped_edge_features.append(edge_features[mask]) + adjacency_lists = dropped_adj_list + edge_feature_embeddings = dropped_edge_features + all_states = [node_representations] for mp_layer_idx, mp_layer in enumerate(self.__message_passing_layers): node_representations = mp_layer( @@ -102,6 +126,7 @@ def gnn( node_to_graph_idx=node_to_graph_idx, reference_node_ids=reference_node_ids, reference_node_graph_idx=reference_node_graph_idx, + edge_features=edge_feature_embeddings, ) all_states.append(node_representations) if return_all_states: @@ -111,8 +136,9 @@ def gnn( def forward( self, *, - node_data: torch.Tensor, + node_data, adjacency_lists: List[Tuple[torch.Tensor, torch.Tensor]], + edge_feature_data: List, node_to_graph_idx: torch.Tensor, reference_node_ids: Dict[str, torch.Tensor], reference_node_graph_idx: Dict[str, torch.Tensor], @@ -121,8 +147,10 @@ def forward( ) -> GnnOutput: """ - :param node_data: A [num_nodes, D_input] matrix of the initial node representations. + :param node_data: The data for the node embedder to compute the initial node representations. :param adjacency_lists: A list of [num_edges, 2] matrices for each edge type. + :param edge_feature_data: A list of the same size as `adjacency_lists` with the data + to compute the edge features. :param node_to_graph_idx: A [num_nodes] vector that contains the index of the graph it belongs to. :param reference_node_ids: A dictionary with values the indices of the reference nodes. :param reference_node_graph_idx: A dictionary with values the index of the graph that each @@ -131,18 +159,36 @@ def forward( """ initial_node_representations = self.__node_embedder(**node_data) # [num_nodes, D] + if self.__edge_feature_embedder is None: + edge_feature_embeddings = [ + torch.empty(f.shape[0], 0, device=node_to_graph_idx.device) + for f, _ in adjacency_lists + ] + else: + edge_feature_embeddings = [ + self.__edge_feature_embedder(**edge_data) for edge_data in edge_feature_data + ] + if self.__introduce_backwards_edges: adjacency_lists += [(t, f) for f, t in adjacency_lists] + edge_feature_embeddings += [e for e in edge_feature_embeddings] + if self.__add_self_edges: - num_nodes = initial_node_representations.shape[0] - idents = torch.arange( - num_nodes, dtype=torch.int64, device=initial_node_representations.device - ) + num_nodes = node_to_graph_idx.shape[0] + idents = torch.arange(num_nodes, dtype=torch.int64, device=node_to_graph_idx.device) adjacency_lists.append((idents, idents)) + edge_feature_embeddings.append( + torch.zeros( + num_nodes, + edge_feature_embeddings[-1].shape[-1], + device=node_to_graph_idx.device, + ) + ) output_representations = self.gnn( initial_node_representations, adjacency_lists, + edge_feature_embeddings, node_to_graph_idx, reference_node_ids, reference_node_graph_idx, @@ -152,7 +198,7 @@ def forward( with torch.no_grad(): self.__num_edges += int(sum(adj[0].shape[0] for adj in adjacency_lists)) self.__num_graphs += int(num_graphs) - self.__num_nodes += int(initial_node_representations.shape[0]) + self.__num_nodes += int(node_to_graph_idx.shape[0]) return GnnOutput( input_node_representations=initial_node_representations, output_node_representations=output_representations, @@ -164,13 +210,15 @@ def forward( TNodeData = TypeVar("TNodeData") +TEdgeData = TypeVar("TEdgeData") TTensorizedNodeData = TypeVar("TTensorizedNodeData") +TTensorizedEdgeData = TypeVar("TTensorizedEdgeData") class GraphNeuralNetworkModel( AbstractNeuralModel[ - GraphData[TNodeData], - TensorizedGraphData[TTensorizedNodeData], + GraphData[TNodeData, TEdgeData], + TensorizedGraphData[TTensorizedNodeData, TTensorizedEdgeData], GraphNeuralNetwork, ], ): @@ -185,7 +233,11 @@ def __init__( max_graph_edges: int = 100000, introduce_backwards_edges: bool = True, stop_extending_minibatch_after_num_nodes: int = 10000, - add_self_edges: bool = False + add_self_edges: bool = False, + edge_dropout_rate: float = 0.0, + edge_representation_model: Optional[ + AbstractNeuralModel[TEdgeData, TTensorizedEdgeData, nn.Module] + ] = None ): """ :param node_representation_model: A model that can convert the data of each node into their @@ -196,6 +248,7 @@ def __init__( super().__init__() self.__message_passing_layers_creator: Final = message_passing_layer_creator self.__node_embedding_model: Final = node_representation_model + self.__edge_embedding_model: Final = edge_representation_model self.max_nodes_per_graph: Final = max_nodes_per_graph self.max_graph_edges: Final = max_graph_edges self.introduce_backwards_edges: Final = introduce_backwards_edges @@ -203,18 +256,24 @@ def __init__( stop_extending_minibatch_after_num_nodes ) self.add_self_edges: Final = add_self_edges + self.__edge_dropout_rate = edge_dropout_rate # region Metadata Loading def initialize_metadata(self) -> None: self.__edge_types_mdata: Set[str] = set() - def update_metadata_from(self, datapoint: GraphData) -> None: + def update_metadata_from(self, datapoint: GraphData[TNodeData, TEdgeData]) -> None: for node in datapoint.node_information: self.__node_embedding_model.update_metadata_from(node) for edge_type in datapoint.edges: self.__edge_types_mdata.add(edge_type) + if datapoint.edge_features is not None and self.__edge_embedding_model is not None: + for edge_features in datapoint.edge_features.values(): + for edge_feature in edge_features: + self.__edge_embedding_model.update_metadata_from(edge_feature) + def finalize_metadata(self) -> None: self.LOGGER.info("Found %s edge types in data.", len(self.__edge_types_mdata)) self.__edge_idx_to_type = tuple(self.__edge_types_mdata) @@ -231,11 +290,18 @@ def _num_edge_types(self) -> int: return num_types def build_neural_module(self) -> GraphNeuralNetwork: + if self.__edge_embedding_model is None: + edge_feature_embedder = None + else: + edge_feature_embedder = self.__edge_embedding_model.build_neural_module() + gnn = GraphNeuralNetwork( self.__message_passing_layers_creator(self._num_edge_types), node_embedder=self.__node_embedding_model.build_neural_module(), introduce_backwards_edges=self.introduce_backwards_edges, add_self_edges=self.add_self_edges, + edge_dropout_rate=self.__edge_dropout_rate, + edge_feature_embedder=edge_feature_embedder, ) del self.__message_passing_layers_creator return gnn @@ -246,7 +312,7 @@ def edge_idx_by_name(self, name: str) -> int: return self.__edge_types[name] def __iterate_edge_types( - self, data_to_load: GraphData[TNodeData] + self, data_to_load: GraphData[TNodeData, TEdgeData] ) -> Iterator[Tuple[np.ndarray, np.ndarray]]: for edge_type in self.__edge_idx_to_type: adjacency_list = data_to_load.edges.get(edge_type) @@ -257,14 +323,36 @@ def __iterate_edge_types( yield np.zeros(0, dtype=np.int32), np.zeros(0, dtype=np.int32) def tensorize( - self, datapoint: GraphData[TNodeData] - ) -> Optional[TensorizedGraphData[TTensorizedNodeData]]: + self, datapoint: GraphData[TNodeData, TEdgeData] + ) -> Optional[TensorizedGraphData[TTensorizedNodeData, TTensorizedEdgeData]]: + if len(datapoint.node_information) > self.max_nodes_per_graph: + self.LOGGER.warning("Dropping graph with %s nodes." % len(datapoint.node_information)) + return None + + if self.__edge_embedding_model is None: + tensorized_edge_features = None + else: + tensorized_edge_features = [] + for edge_type in self.__edge_idx_to_type: + edge_features_for_edge_type = datapoint.edge_features.get(edge_type) + if edge_features_for_edge_type is None: + # No edges of type `edge_type` + tensorized_edge_features.append([]) + else: + tensorized_edge_features.append( + [ + self.__edge_embedding_model.tensorize(e) + for e in edge_features_for_edge_type + ] + ) + tensorized_data = TensorizedGraphData( adjacency_lists=list(self.__iterate_edge_types(datapoint)), node_tensorized_data=[ enforce_not_None(self.__node_embedding_model.tensorize(ni)) for ni in datapoint.node_information ], + edge_features=tensorized_edge_features, reference_nodes={ n: np.array(np.array(refs, dtype=np.int32)) for n, refs in datapoint.reference_nodes.items() @@ -272,10 +360,6 @@ def tensorize( num_nodes=len(datapoint.node_information), ) - if tensorized_data.num_nodes > self.max_nodes_per_graph: - self.LOGGER.warning("Dropping graph with %s nodes." % tensorized_data.num_nodes) - return None - num_edges = sum(len(adj) for adj in tensorized_data.adjacency_lists) if num_edges > self.max_graph_edges: self.LOGGER.warning("Dropping graph with %s edges." % num_edges) @@ -288,6 +372,12 @@ def initialize_minibatch(self) -> Dict[str, Any]: return { "node_data_mb": self.__node_embedding_model.initialize_minibatch(), "adjacency_lists": [([], []) for _ in range(len(self.__edge_types))], + "edge_feature_data": [ + self.__edge_embedding_model.initialize_minibatch() + if self.__edge_embedding_model is not None + else None + for _ in range(len(self.__edge_types)) + ], "num_nodes_per_graph": [], "reference_node_graph_idx": defaultdict(list), "reference_node_ids": defaultdict(list), @@ -295,7 +385,9 @@ def initialize_minibatch(self) -> Dict[str, Any]: } def extend_minibatch_with( - self, tensorized_datapoint: TensorizedGraphData, partial_minibatch: Dict[str, Any] + self, + tensorized_datapoint: TensorizedGraphData[TTensorizedNodeData, TTensorizedEdgeData], + partial_minibatch: Dict[str, Any], ) -> bool: continue_extending = True for node_tensorized_info in tensorized_datapoint.node_tensorized_data: @@ -306,9 +398,23 @@ def extend_minibatch_with( graph_idx = len(partial_minibatch["num_nodes_per_graph"]) adj_list = partial_minibatch["adjacency_lists"] + tensorized_edge_feature_data = partial_minibatch["edge_feature_data"] nodes_in_mb_so_far = partial_minibatch["num_nodes_in_mb"] - for sample_adj_list_for_edge_type, mb_adj_lists_for_edge_type in zip( - tensorized_datapoint.adjacency_lists, adj_list + + datapoint_edge_features = tensorized_datapoint.edge_features + if datapoint_edge_features is None: + datapoint_edge_features = [None for _ in range(len(adj_list))] + + for ( + sample_adj_list_for_edge_type, + edge_features, + mb_adj_lists_for_edge_type, + mb_edge_feature_data, + ) in zip( + tensorized_datapoint.adjacency_lists, + datapoint_edge_features, + adj_list, + tensorized_edge_feature_data, ): mb_adj_lists_for_edge_type[0].append( sample_adj_list_for_edge_type[0] + nodes_in_mb_so_far @@ -316,6 +422,11 @@ def extend_minibatch_with( mb_adj_lists_for_edge_type[1].append( sample_adj_list_for_edge_type[1] + nodes_in_mb_so_far ) + if self.__edge_embedding_model is not None: + for edge_feature in edge_features: + self.__edge_embedding_model.extend_minibatch_with( + edge_feature, mb_edge_feature_data + ) for ref_name, ref_nodes in tensorized_datapoint.reference_nodes.items(): partial_minibatch["reference_node_graph_idx"][ref_name].extend( @@ -335,6 +446,15 @@ def __create_node_to_graph_idx(num_nodes_per_graph: List[int]) -> Iterable[int]: def finalize_minibatch( self, accumulated_minibatch_data: Dict[str, Any], device: Union[str, torch.device] ) -> Dict[str, Any]: + + if self.__edge_embedding_model is None: + edge_feature_data = [None for _ in accumulated_minibatch_data["edge_feature_data"]] + else: + edge_feature_data = [ + self.__edge_embedding_model.finalize_minibatch(edge_features_for_type, device) + for edge_features_for_type in accumulated_minibatch_data["edge_feature_data"] + ] + return { "node_data": self.__node_embedding_model.finalize_minibatch( accumulated_minibatch_data["node_data_mb"], device @@ -346,6 +466,7 @@ def finalize_minibatch( ) for adjFrom, adjTo in accumulated_minibatch_data["adjacency_lists"] ], + "edge_feature_data": edge_feature_data, "node_to_graph_idx": torch.tensor( list( self.__create_node_to_graph_idx( diff --git a/ptgnn/neuralmodels/gnn/messagepassing/abstractmessagepassing.py b/ptgnn/neuralmodels/gnn/messagepassing/abstractmessagepassing.py index eca7948..5ce698c 100644 --- a/ptgnn/neuralmodels/gnn/messagepassing/abstractmessagepassing.py +++ b/ptgnn/neuralmodels/gnn/messagepassing/abstractmessagepassing.py @@ -16,6 +16,7 @@ def forward( node_to_graph_idx: torch.Tensor, reference_node_ids: Dict[str, torch.Tensor], reference_node_graph_idx: Dict[str, torch.Tensor], + edge_features: List[torch.Tensor], ) -> torch.Tensor: """ :param node_states: A [num_nodes, D] matrix containing the states of all nodes. @@ -23,6 +24,8 @@ def forward( :param node_to_graph_idx: :param reference_node_ids: :param reference_node_graph_idx: + :param edge_features: A list of [num_edges, H] with edge features. + Has the size of `adjacency_lists`. :return: the next node states in a [num_nodes, D'] matrix. """ diff --git a/ptgnn/neuralmodels/gnn/messagepassing/gatedmessagepassing.py b/ptgnn/neuralmodels/gnn/messagepassing/gatedmessagepassing.py index 9113e1e..af8ab46 100644 --- a/ptgnn/neuralmodels/gnn/messagepassing/gatedmessagepassing.py +++ b/ptgnn/neuralmodels/gnn/messagepassing/gatedmessagepassing.py @@ -13,12 +13,13 @@ def __init__( num_edge_types: int, message_aggregation_function: str, dropout_rate: float = 0.0, + edge_feature_dimension: int = 0, ): super().__init__() self.__edge_message_transformation_layers = nn.ModuleList( [ - nn.Linear(state_dimension, message_dimension, bias=False) + nn.Linear(state_dimension + edge_feature_dimension, message_dimension, bias=False) for _ in range(num_edge_types) ] ) @@ -40,20 +41,23 @@ def forward( node_to_graph_idx: torch.Tensor, reference_node_ids: Dict[str, torch.Tensor], reference_node_graph_idx: Dict[str, torch.Tensor], + edge_features: List[torch.Tensor], ) -> torch.Tensor: message_targets = torch.cat([adj_list[1] for adj_list in adjacency_lists]) # num_messages assert len(adjacency_lists) == len(self.__edge_message_transformation_layers) all_messages = [] - for edge_type_idx, (adj_list, edge_transformation_layer) in enumerate( - zip(adjacency_lists, self.__edge_message_transformation_layers) + for edge_type_idx, (adj_list, features, edge_transformation_layer) in enumerate( + zip(adjacency_lists, edge_features, self.__edge_message_transformation_layers) ): edge_sources_idxs = adj_list[0] edge_source_states = nn.functional.embedding( edge_sources_idxs, node_states ) # [num_edges_of_type_edge_type_idx, H] all_messages.append( - edge_transformation_layer(self.__dropout(edge_source_states)) + edge_transformation_layer( + self.__dropout(torch.cat([edge_source_states, features], -1)) + ) ) # [num_edges_of_type_edge_type_idx, D] aggregated_messages = self._aggregate_messages( diff --git a/ptgnn/neuralmodels/gnn/messagepassing/globalgraphexchange.py b/ptgnn/neuralmodels/gnn/messagepassing/globalgraphexchange.py index 48d5570..88b3ba2 100644 --- a/ptgnn/neuralmodels/gnn/messagepassing/globalgraphexchange.py +++ b/ptgnn/neuralmodels/gnn/messagepassing/globalgraphexchange.py @@ -33,6 +33,7 @@ def forward( node_to_graph_idx: torch.Tensor, reference_node_ids: Dict[str, torch.Tensor], reference_node_graph_idx: Dict[str, torch.Tensor], + edge_features: List[torch.Tensor], ) -> torch.Tensor: e = ElementsToSummaryRepresentationInput( element_embeddings=node_states, diff --git a/ptgnn/neuralmodels/gnn/messagepassing/graphnorm.py b/ptgnn/neuralmodels/gnn/messagepassing/graphnorm.py index f5a519c..560bf82 100644 --- a/ptgnn/neuralmodels/gnn/messagepassing/graphnorm.py +++ b/ptgnn/neuralmodels/gnn/messagepassing/graphnorm.py @@ -31,6 +31,7 @@ def forward( node_to_graph_idx: torch.Tensor, reference_node_ids: Dict[str, torch.Tensor], reference_node_graph_idx: Dict[str, torch.Tensor], + edge_features: List[torch.Tensor], ) -> torch.Tensor: per_graph_mean = scatter_mean( node_states, index=node_to_graph_idx, dim=0 diff --git a/ptgnn/neuralmodels/gnn/messagepassing/mlpmessagepassing.py b/ptgnn/neuralmodels/gnn/messagepassing/mlpmessagepassing.py index b6ace14..f5fca5a 100644 --- a/ptgnn/neuralmodels/gnn/messagepassing/mlpmessagepassing.py +++ b/ptgnn/neuralmodels/gnn/messagepassing/mlpmessagepassing.py @@ -21,6 +21,7 @@ def __init__( use_dense_layer: bool = True, dropout_rate: float = 0.0, dense_activation: Optional[nn.Module] = nn.Tanh(), + features_dimension: int = 0, ): super().__init__() self.__input_state_dim = input_state_dimension @@ -34,7 +35,7 @@ def __init__( self.__edge_message_transformation_layers = nn.ModuleList( [ MLP( - input_dimension=message_input_size, + input_dimension=message_input_size + features_dimension, output_dimension=message_dimension, hidden_layers=mlp_hidden_layers, ) @@ -63,12 +64,13 @@ def forward( node_to_graph_idx: torch.Tensor, reference_node_ids: Dict[str, torch.Tensor], reference_node_graph_idx: Dict[str, torch.Tensor], + edge_features: List[torch.Tensor], ) -> torch.Tensor: assert len(adjacency_lists) == len(self.__edge_message_transformation_layers) all_message_targets, all_messages = [], [] - for edge_type_idx, (adj_list, edge_transformation_layer) in enumerate( - zip(adjacency_lists, self.__edge_message_transformation_layers) + for edge_type_idx, (adj_list, features, edge_transformation_layer) in enumerate( + zip(adjacency_lists, edge_features, self.__edge_message_transformation_layers) ): edge_sources_idxs, edge_target_idxs = adj_list all_message_targets.append(edge_target_idxs) @@ -81,7 +83,9 @@ def forward( else: message_input = edge_source_states - all_messages.append(edge_transformation_layer(message_input)) + all_messages.append( + edge_transformation_layer(torch.cat([message_input, features], dim=-1)) + ) aggregated_messages = self._aggregate_messages( messages=torch.cat(all_messages, dim=0), diff --git a/ptgnn/neuralmodels/gnn/messagepassing/residuallayers.py b/ptgnn/neuralmodels/gnn/messagepassing/residuallayers.py index f716e63..220f1fa 100644 --- a/ptgnn/neuralmodels/gnn/messagepassing/residuallayers.py +++ b/ptgnn/neuralmodels/gnn/messagepassing/residuallayers.py @@ -26,6 +26,7 @@ def forward( node_to_graph_idx: torch.Tensor, reference_node_ids: Dict[str, torch.Tensor], reference_node_graph_idx: Dict[str, torch.Tensor], + edge_features: List[torch.Tensor], ) -> torch.Tensor: self.__target_layer._original_input = node_states return node_states @@ -47,6 +48,7 @@ def forward( node_to_graph_idx: torch.Tensor, reference_node_ids: Dict[str, torch.Tensor], reference_node_graph_idx: Dict[str, torch.Tensor], + edge_features: List[torch.Tensor], ) -> torch.Tensor: assert self._original_input is not None, "Initial Pass Through Layer was not used." out = torch.stack((self._original_input, node_states), dim=-1).mean(dim=-1) @@ -78,6 +80,7 @@ def forward( node_to_graph_idx: torch.Tensor, reference_node_ids: Dict[str, torch.Tensor], reference_node_graph_idx: Dict[str, torch.Tensor], + edge_features: List[torch.Tensor], ) -> torch.Tensor: assert self._original_input is not None, "Initial Pass Through Layer was not used." out = torch.cat((self._original_input, node_states), dim=-1) @@ -123,6 +126,7 @@ def forward( node_to_graph_idx: torch.Tensor, reference_node_ids: Dict[str, torch.Tensor], reference_node_graph_idx: Dict[str, torch.Tensor], + edge_features: List[torch.Tensor], ) -> torch.Tensor: assert self._original_input is not None, "Initial Pass Through Layer was not used." out = self.__linear_combination(torch.cat((self._original_input, node_states), axis=-1)) diff --git a/ptgnn/neuralmodels/gnn/messagepassing/selfattmessagepassing.py b/ptgnn/neuralmodels/gnn/messagepassing/selfattmessagepassing.py index 5aac376..d9e4ded 100644 --- a/ptgnn/neuralmodels/gnn/messagepassing/selfattmessagepassing.py +++ b/ptgnn/neuralmodels/gnn/messagepassing/selfattmessagepassing.py @@ -81,6 +81,7 @@ def forward( node_to_graph_idx: torch.Tensor, reference_node_ids: Dict[str, torch.Tensor], reference_node_graph_idx: Dict[str, torch.Tensor], + edge_features: List[torch.Tensor], # Not used ) -> torch.Tensor: if self.__target_reference == "all": relevant_node_states = node_states diff --git a/ptgnn/neuralmodels/gnn/structs.py b/ptgnn/neuralmodels/gnn/structs.py index 9490afd..7089b78 100644 --- a/ptgnn/neuralmodels/gnn/structs.py +++ b/ptgnn/neuralmodels/gnn/structs.py @@ -1,39 +1,51 @@ import numpy as np import torch from abc import ABC, abstractmethod -from typing import Dict, Generic, List, NamedTuple, Tuple, TypeVar +from typing import Dict, Generic, List, NamedTuple, Optional, Tuple, TypeVar TNodeData = TypeVar("TNodeData") +TEdgeData = TypeVar("TEdgeData") TTensorizedNodeData = TypeVar("TTensorizedNodeData") +TTensorizedEdgeData = TypeVar("TTensorizedEdgeData") -class GraphData(Generic[TNodeData]): - __slots__ = ("node_information", "edges", "reference_nodes") +class GraphData(Generic[TNodeData, TEdgeData]): + __slots__ = ("node_information", "edges", "edge_features", "reference_nodes") def __init__( self, node_information: List[TNodeData], edges: Dict[str, List[Tuple[int, int]]], reference_nodes: Dict[str, List[int]], + edge_features: Optional[Dict[str, List[TEdgeData]]] = None, ): self.node_information = node_information self.edges = edges + self.edge_features = edge_features self.reference_nodes = reference_nodes -class TensorizedGraphData(Generic[TTensorizedNodeData]): - __slots__ = ("num_nodes", "node_tensorized_data", "adjacency_lists", "reference_nodes") +class TensorizedGraphData(Generic[TTensorizedNodeData, TTensorizedEdgeData]): + __slots__ = ( + "num_nodes", + "node_tensorized_data", + "adjacency_lists", + "edge_features", + "reference_nodes", + ) def __init__( self, num_nodes: int, node_tensorized_data: List[TTensorizedNodeData], adjacency_lists: List[Tuple[np.ndarray, np.ndarray]], + edge_features: Optional[List[TTensorizedEdgeData]], reference_nodes: Dict[str, np.ndarray], ): self.num_nodes = num_nodes self.node_tensorized_data = node_tensorized_data self.adjacency_lists = adjacency_lists + self.edge_features = edge_features self.reference_nodes = reference_nodes