From a6689b617ab6f98d927ba2a92a3b2f52672c1aea Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Guillermo=20Bern=C3=A1rdez?= Date: Wed, 13 Nov 2024 18:03:06 -0800 Subject: [PATCH 1/9] GraphMLP working without additional loss --- configs/model/graph/graph_mlp.yaml | 37 ++++++ configs/run.yaml | 4 +- topobenchmarkx/nn/backbones/__init__.py | 9 +- topobenchmarkx/nn/backbones/graph/__init__.py | 2 + .../nn/backbones/graph/graph_mlp.py | 121 ++++++++++++++++++ topobenchmarkx/nn/wrappers/__init__.py | 3 +- topobenchmarkx/nn/wrappers/graph/__init__.py | 2 + .../nn/wrappers/graph/graph_mlp_wrapper.py | 60 +++++++++ 8 files changed, 234 insertions(+), 4 deletions(-) create mode 100755 configs/model/graph/graph_mlp.yaml create mode 100644 topobenchmarkx/nn/backbones/graph/graph_mlp.py create mode 100644 topobenchmarkx/nn/wrappers/graph/graph_mlp_wrapper.py diff --git a/configs/model/graph/graph_mlp.yaml b/configs/model/graph/graph_mlp.yaml new file mode 100755 index 00000000..d69dd747 --- /dev/null +++ b/configs/model/graph/graph_mlp.yaml @@ -0,0 +1,37 @@ +_target_: topobenchmarkx.model.TBXModel + +model_name: gat +model_domain: graph + +feature_encoder: + _target_: topobenchmarkx.nn.encoders.${model.feature_encoder.encoder_name} + encoder_name: AllCellFeatureEncoder + in_channels: ${infer_in_channels:${dataset},${oc.select:transforms,null}} + out_channels: 32 + proj_dropout: 0.0 + +backbone: + _target_: topobenchmarkx.nn.backbones.GraphMLP + in_channels: ${model.feature_encoder.out_channels} + hidden_channels: ${model.feature_encoder.out_channels} + order: 2 + dropout: 0.0 + +backbone_wrapper: + _target_: topobenchmarkx.nn.wrappers.GraphMLPWrapper + _partial_: true + wrapper_name: GraphMLPWrapper + out_channels: ${model.feature_encoder.out_channels} + num_cell_dimensions: ${infere_num_cell_dimensions:${oc.select:model.feature_encoder.selected_dimensions,null},${model.feature_encoder.in_channels}} + +readout: + _target_: topobenchmarkx.nn.readouts.${model.readout.readout_name} + readout_name: NoReadOut # Use in case readout is not needed Options: PropagateSignalDown + num_cell_dimensions: ${infere_num_cell_dimensions:${oc.select:model.feature_encoder.selected_dimensions,null},${model.feature_encoder.in_channels}} # The highest order of cell dimensions to consider + hidden_dim: ${model.feature_encoder.out_channels} + out_channels: ${dataset.parameters.num_classes} + task_level: ${dataset.parameters.task_level} + pooling_type: sum + +# compile model for faster training with pytorch 2.0 +compile: false diff --git a/configs/run.yaml b/configs/run.yaml index 5568de72..d7192338 100755 --- a/configs/run.yaml +++ b/configs/run.yaml @@ -5,14 +5,14 @@ defaults: - _self_ - dataset: graph/cocitation_cora - - model: hypergraph/unignn2 + - model: graph/graph_mlp - transforms: ${get_default_transform:${dataset},${model}} #no_transform - optimizer: default - loss: default - evaluator: default - callbacks: default - logger: wandb # set logger here or use command line (e.g. 
`python train.py logger=tensorboard`) - - trainer: default + - trainer: cpu - paths: default - extras: default - hydra: default diff --git a/topobenchmarkx/nn/backbones/__init__.py b/topobenchmarkx/nn/backbones/__init__.py index 37866e59..c2ad4b25 100644 --- a/topobenchmarkx/nn/backbones/__init__.py +++ b/topobenchmarkx/nn/backbones/__init__.py @@ -4,13 +4,20 @@ CCCN, ) from .combinatorial import TopoTune, TopoTune_OneHasse -from .graph import IdentityGAT, IdentityGCN, IdentityGIN, IdentitySAGE +from .graph import ( + GraphMLP, + IdentityGAT, + IdentityGCN, + IdentityGIN, + IdentitySAGE, +) from .hypergraph import EDGNN from .simplicial import SCCNNCustom __all__ = [ "CCCN", "EDGNN", + "GraphMLP", "SCCNNCustom", "TopoTune", "TopoTune_OneHasse", diff --git a/topobenchmarkx/nn/backbones/graph/__init__.py b/topobenchmarkx/nn/backbones/graph/__init__.py index 6f46e69d..59e3996f 100644 --- a/topobenchmarkx/nn/backbones/graph/__init__.py +++ b/topobenchmarkx/nn/backbones/graph/__init__.py @@ -13,6 +13,7 @@ Node2Vec, ) +from .graph_mlp import GraphMLP from .identity_gnn import ( IdentityGAT, IdentityGCN, @@ -25,6 +26,7 @@ "IdentityGIN", "IdentityGAT", "IdentitySAGE", + "GraphMLP", "MLP", "GCN", "GraphSAGE", diff --git a/topobenchmarkx/nn/backbones/graph/graph_mlp.py b/topobenchmarkx/nn/backbones/graph/graph_mlp.py new file mode 100644 index 00000000..0062c535 --- /dev/null +++ b/topobenchmarkx/nn/backbones/graph/graph_mlp.py @@ -0,0 +1,121 @@ +"""Graph MLP backbone from https://github.com/yanghu819/Graph-MLP/blob/master/models.py.""" + +import torch +import torch.nn as nn +from torch.nn import Dropout, LayerNorm, Linear + + +class GraphMLP(nn.Module): + """ "Graph MLP backbone. + + Parameters + ---------- + in_channels : int + Number of input features. + hidden_channels : int + Number of hidden units. + order : int, optional + To compute order-th power of adj matrix (default: 1). + dropout : float, optional + Dropout rate (default: 0.0). + """ + + def __init__(self, in_channels, hidden_channels, order=1, dropout=0.0): + super().__init__() + self.nhid = hidden_channels + self.order = order + self.mlp = Mlp(in_channels, self.nhid, dropout) + + def forward(self, x): + """Forward pass. + + Parameters + ---------- + x : torch.Tensor + Input tensor. + + Returns + ------- + torch.Tensor + Output tensor. + """ + x = self.mlp(x) + Z = x + + x_dis = get_feature_dis(Z) if self.training else None + + return x, x_dis + + +class Mlp(nn.Module): + """MLP module. + + Parameters + ---------- + input_dim : int + Input dimension. + hid_dim : int + Hidden dimension. + dropout : float + Dropout rate. + """ + + def __init__(self, input_dim, hid_dim, dropout): + super().__init__() + self.fc1 = Linear(input_dim, hid_dim) + self.fc2 = Linear(hid_dim, hid_dim) + self.act_fn = torch.nn.functional.gelu + self._init_weights() + + self.dropout = Dropout(dropout) + self.layernorm = LayerNorm(hid_dim, eps=1e-6) + + def _init_weights(self): + """Initialize weights.""" + nn.init.xavier_uniform_(self.fc1.weight) + nn.init.xavier_uniform_(self.fc2.weight) + nn.init.normal_(self.fc1.bias, std=1e-6) + nn.init.normal_(self.fc2.bias, std=1e-6) + + def forward(self, x): + """Forward pass. + + Parameters + ---------- + x : torch.Tensor + Input tensor. + + Returns + ------- + torch.Tensor + Output tensor. + """ + x = self.fc1(x) + x = self.act_fn(x) + x = self.layernorm(x) + x = self.dropout(x) + x = self.fc2(x) + return x + + +def get_feature_dis(x): + """Get feature distance matrix. 
+ + Parameters + ---------- + x : torch.Tensor + Input tensor. + + Returns + ------- + torch.Tensor + Feature distance matrix. + """ + x_dis = x @ x.T + mask = torch.eye(x_dis.shape[0]).to(x_dis.device) + x_sum = torch.sum(x**2, 1).reshape(-1, 1) + x_sum = torch.sqrt(x_sum).reshape(-1, 1) + x_sum = x_sum @ x_sum.T + x_dis = x_dis * (x_sum ** (-1)) + x_dis = (1 - mask) * x_dis + return x_dis diff --git a/topobenchmarkx/nn/wrappers/__init__.py b/topobenchmarkx/nn/wrappers/__init__.py index bff65cec..f1d26851 100755 --- a/topobenchmarkx/nn/wrappers/__init__.py +++ b/topobenchmarkx/nn/wrappers/__init__.py @@ -8,7 +8,7 @@ CWNWrapper, ) from topobenchmarkx.nn.wrappers.combinatorial import TuneWrapper -from topobenchmarkx.nn.wrappers.graph import GNNWrapper +from topobenchmarkx.nn.wrappers.graph import GNNWrapper, GraphMLPWrapper from topobenchmarkx.nn.wrappers.hypergraph import HypergraphWrapper from topobenchmarkx.nn.wrappers.simplicial import ( SANWrapper, @@ -26,6 +26,7 @@ # Export all wrappers __all__ = [ "AbstractWrapper", + "GraphMLPWrapper", "GNNWrapper", "HypergraphWrapper", "SANWrapper", diff --git a/topobenchmarkx/nn/wrappers/graph/__init__.py b/topobenchmarkx/nn/wrappers/graph/__init__.py index 03168781..32b9c4c4 100644 --- a/topobenchmarkx/nn/wrappers/graph/__init__.py +++ b/topobenchmarkx/nn/wrappers/graph/__init__.py @@ -1,8 +1,10 @@ """Wrappers for graph models.""" from .gnn_wrapper import GNNWrapper +from .graph_mlp_wrapper import GraphMLPWrapper # Export all wrappers __all__ = [ "GNNWrapper", + "GraphMLPWrapper", ] diff --git a/topobenchmarkx/nn/wrappers/graph/graph_mlp_wrapper.py b/topobenchmarkx/nn/wrappers/graph/graph_mlp_wrapper.py new file mode 100644 index 00000000..d28bae0d --- /dev/null +++ b/topobenchmarkx/nn/wrappers/graph/graph_mlp_wrapper.py @@ -0,0 +1,60 @@ +"""Wrapper for the GNN models.""" + +import torch +import torch_geometric + +from topobenchmarkx.nn.wrappers.base import AbstractWrapper + + +class GraphMLPWrapper(AbstractWrapper): + r"""Wrapper for the GNN models. + + This wrapper defines the forward pass of the model. The GNN models return + the embeddings of the cells of rank 0. + """ + + def get_power_adj(self, edge_index, order=1): + r"""Get the power of the adjacency matrix. + + Parameters + ---------- + edge_index : torch.Tensor + Edge index tensor. + order : int, optional + Order of the adjacency matrix (default: 1). + + Returns + ------- + torch.Tensor + Power of the adjacency matrix. + """ + edge_index, _ = torch_geometric.utils.remove_self_loops(edge_index) + adj = torch_geometric.utils.to_dense_adj(edge_index) + adj_power = adj.clone() + for _ in range(order - 1): + adj_power = torch.matmul(adj_power, adj) + return adj_power + + def forward(self, batch): + r"""Forward pass for the GNN wrapper. + + Parameters + ---------- + batch : torch_geometric.data.Data + Batch object containing the batched data. + + Returns + ------- + dict + Dictionary containing the updated model output. 
+ """ + x_0, x_dis = self.backbone(batch.x_0) + + model_out = {"labels": batch.y, "batch_0": batch.batch_0} + model_out["x_0"] = x_0 + model_out["x_dis"] = x_dis + model_out["adj_label"] = self.get_power_adj( + batch.edge_index, self.backbone.order + ) + + return model_out From b8fc8ac1c25a54f6cb98ef1b94812ec04d7d67c6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Guillermo=20Bern=C3=A1rdez?= Date: Wed, 13 Nov 2024 18:06:10 -0800 Subject: [PATCH 2/9] Adding GraphMLP loss --- topobenchmarkx/loss/custom_losses/__init__.py | 1 + .../loss/custom_losses/graph_mlp_loss.py | 27 +++++++++++++++++++ 2 files changed, 28 insertions(+) create mode 100644 topobenchmarkx/loss/custom_losses/__init__.py create mode 100644 topobenchmarkx/loss/custom_losses/graph_mlp_loss.py diff --git a/topobenchmarkx/loss/custom_losses/__init__.py b/topobenchmarkx/loss/custom_losses/__init__.py new file mode 100644 index 00000000..164d2f8a --- /dev/null +++ b/topobenchmarkx/loss/custom_losses/__init__.py @@ -0,0 +1 @@ +"""Init file for custom loss module.""" diff --git a/topobenchmarkx/loss/custom_losses/graph_mlp_loss.py b/topobenchmarkx/loss/custom_losses/graph_mlp_loss.py new file mode 100644 index 00000000..b5fc8e14 --- /dev/null +++ b/topobenchmarkx/loss/custom_losses/graph_mlp_loss.py @@ -0,0 +1,27 @@ +"""Graph MLP loss function.""" + +import torch + + +def graph_mlp_contrast_loss(x_dis, adj_label, tau=1): + """Graph MLP contrastive loss. + + Parameters + ---------- + x_dis : torch.Tensor + Distance matrix. + adj_label : torch.Tensor + Adjacency matrix. + tau : float, optional + Temperature parameter (default: 1). + + Returns + ------- + torch.Tensor + Contrastive loss. + """ + x_dis = torch.exp(tau * x_dis) + x_dis_sum = torch.sum(x_dis, 1) + x_dis_sum_pos = torch.sum(x_dis * adj_label, 1) + loss = -torch.log(x_dis_sum_pos * (x_dis_sum ** (-1)) + 1e-8).mean() + return loss From b643e86fea9012522b157064cbb2a2b48fb19737 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Guillermo=20Bern=C3=A1rdez?= Date: Wed, 13 Nov 2024 19:19:02 -0800 Subject: [PATCH 3/9] Enabling additional backbone losses (GraphMLP as an example) --- configs/loss/default.yaml | 6 +- configs/model/graph/graph_mlp.yaml | 9 +- .../loss/custom_losses/DatasetLoss.py | 58 +++++++++++ .../loss/custom_losses/GraphMLPLoss.py | 96 +++++++++++++++++++ topobenchmarkx/loss/custom_losses/__init__.py | 8 ++ .../loss/custom_losses/graph_mlp_loss.py | 27 ------ topobenchmarkx/loss/loss.py | 38 +++----- .../nn/backbones/graph/graph_mlp.py | 6 +- .../nn/wrappers/graph/graph_mlp_wrapper.py | 28 ------ 9 files changed, 193 insertions(+), 83 deletions(-) create mode 100644 topobenchmarkx/loss/custom_losses/DatasetLoss.py create mode 100644 topobenchmarkx/loss/custom_losses/GraphMLPLoss.py delete mode 100644 topobenchmarkx/loss/custom_losses/graph_mlp_loss.py diff --git a/configs/loss/default.yaml b/configs/loss/default.yaml index 9ede46d6..af17d268 100644 --- a/configs/loss/default.yaml +++ b/configs/loss/default.yaml @@ -1,3 +1,5 @@ _target_: topobenchmarkx.loss.TBXLoss -task: ${dataset.parameters.task} -loss_type: ${dataset.parameters.loss_type} \ No newline at end of file +dataset_loss: + task: ${dataset.parameters.task} + loss_type: ${dataset.parameters.loss_type} +model_loss: ${oc.select:model.backbone.loss,null} \ No newline at end of file diff --git a/configs/model/graph/graph_mlp.yaml b/configs/model/graph/graph_mlp.yaml index d69dd747..36dab62f 100755 --- a/configs/model/graph/graph_mlp.yaml +++ b/configs/model/graph/graph_mlp.yaml @@ -1,6 +1,6 @@ _target_: 
topobenchmarkx.model.TBXModel -model_name: gat +model_name: GraphMLP model_domain: graph feature_encoder: @@ -16,6 +16,11 @@ backbone: hidden_channels: ${model.feature_encoder.out_channels} order: 2 dropout: 0.0 + loss: + _target_: topobenchmarkx.loss.custom_losses.GraphMLPLoss + r_adj_power: 2 + alpha: 1. + tau: 1. backbone_wrapper: _target_: topobenchmarkx.nn.wrappers.GraphMLPWrapper @@ -33,5 +38,7 @@ readout: task_level: ${dataset.parameters.task_level} pooling_type: sum + + # compile model for faster training with pytorch 2.0 compile: false diff --git a/topobenchmarkx/loss/custom_losses/DatasetLoss.py b/topobenchmarkx/loss/custom_losses/DatasetLoss.py new file mode 100644 index 00000000..0d05a9bc --- /dev/null +++ b/topobenchmarkx/loss/custom_losses/DatasetLoss.py @@ -0,0 +1,58 @@ +"""Loss module for the topobenchmarkx package.""" + +import torch +import torch_geometric + +from topobenchmarkx.loss.base import AbstractLoss + + +class DatasetLoss(AbstractLoss): + r"""Defines the default model loss for the given task. + + Parameters + ---------- + dataset_loss : dict + Dictionary containing the dataset loss information. + """ + + def __init__(self, dataset_loss): + super().__init__() + self.task = dataset_loss["task"] + self.loss_type = dataset_loss["loss_type"] + # Dataset loss + if self.task == "classification" and self.loss_type == "cross_entropy": + self.criterion = torch.nn.CrossEntropyLoss() + elif self.task == "regression" and self.loss_type == "mse": + self.criterion = torch.nn.MSELoss() + elif self.task == "regression" and self.loss_type == "mae": + self.criterion = torch.nn.L1Loss() + else: + raise Exception("Loss is not defined") + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(task={self.task}, loss_type={self.loss_type})" + + def forward(self, model_out: dict, batch: torch_geometric.data.Data): + r"""Forward pass of the loss function. + + Parameters + ---------- + model_out : dict + Dictionary containing the model output. + batch : torch_geometric.data.Data + Batch object containing the batched domain data. + + Returns + ------- + dict + Dictionary containing the model output with the loss. + """ + logits = model_out["logits"] + target = model_out["labels"] + + if self.task == "regression": + target = target.unsqueeze(1) + + dataset_loss = self.criterion(logits, target) + + return dataset_loss diff --git a/topobenchmarkx/loss/custom_losses/GraphMLPLoss.py b/topobenchmarkx/loss/custom_losses/GraphMLPLoss.py new file mode 100644 index 00000000..ed0646cd --- /dev/null +++ b/topobenchmarkx/loss/custom_losses/GraphMLPLoss.py @@ -0,0 +1,96 @@ +"""Graph MLP loss function.""" + +import torch +import torch_geometric + +from topobenchmarkx.loss.base import AbstractLoss + + +class GraphMLPLoss(AbstractLoss): + r"""Graph MLP loss function. + + Parameters + ---------- + r_adj_power : int, optional + Power of the adjacency matrix (default: 2). + alpha : float, optional + Alpha parameter (default: 1). + tau : float, optional + Temperature parameter (default: 1). + """ + + def __init__(self, r_adj_power=2, alpha=0.5, tau=1.0): + super().__init__() + self.r_adj_power = r_adj_power + self.alpha = alpha + self.tau = tau + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(r_adj_power={self.r_adj_power}, alpha={self.alpha}, tau={self.tau})" + + def get_power_adj(self, edge_index): + r"""Get the power of the adjacency matrix. + + Parameters + ---------- + edge_index : torch.Tensor + Edge index tensor. 
+ + Returns + ------- + torch.Tensor + Power of the adjacency matrix. + """ + edge_index, _ = torch_geometric.utils.remove_self_loops(edge_index) + adj = torch_geometric.utils.to_dense_adj(edge_index) + adj_power = adj.clone() + for _ in range(self.r_adj_power - 1): + adj_power = torch.matmul(adj_power, adj) + return adj_power + + def graph_mlp_contrast_loss(self, x_dis, adj_label): + """Graph MLP contrastive loss. + + Parameters + ---------- + x_dis : torch.Tensor + Distance matrix. + adj_label : torch.Tensor + Adjacency matrix. + + Returns + ------- + torch.Tensor + Contrastive loss. + """ + x_dis = torch.exp(self.tau * x_dis) + x_dis_sum = torch.sum(x_dis, 1) + x_dis_sum_pos = torch.sum(x_dis * adj_label, 1) + loss = -torch.log(x_dis_sum_pos * (x_dis_sum ** (-1)) + 1e-8).mean() + return loss + + def forward( + self, model_out: dict, batch: torch_geometric.data.Data + ) -> torch.Tensor: + r"""Forward pass of the loss function. + + Parameters + ---------- + model_out : dict + Dictionary containing the model output. + batch : torch_geometric.data.Data + Batch object containing the batched domain data. + + Returns + ------- + dict + Dictionary containing the model output with the loss. + """ + x_dis = model_out["x_dis"] + if x_dis is None: # Validation and test + return torch.tensor(0.0) + adj_label = self.get_power_adj(batch.edge_index) + graph_mlp_loss = self.alpha * self.graph_mlp_contrast_loss( + x_dis, adj_label + ) + return graph_mlp_loss diff --git a/topobenchmarkx/loss/custom_losses/__init__.py b/topobenchmarkx/loss/custom_losses/__init__.py index 164d2f8a..8bcdc7b2 100644 --- a/topobenchmarkx/loss/custom_losses/__init__.py +++ b/topobenchmarkx/loss/custom_losses/__init__.py @@ -1 +1,9 @@ """Init file for custom loss module.""" + +from .DatasetLoss import DatasetLoss +from .GraphMLPLoss import GraphMLPLoss + +__all__ = [ + "GraphMLPLoss", + "DatasetLoss", +] diff --git a/topobenchmarkx/loss/custom_losses/graph_mlp_loss.py b/topobenchmarkx/loss/custom_losses/graph_mlp_loss.py deleted file mode 100644 index b5fc8e14..00000000 --- a/topobenchmarkx/loss/custom_losses/graph_mlp_loss.py +++ /dev/null @@ -1,27 +0,0 @@ -"""Graph MLP loss function.""" - -import torch - - -def graph_mlp_contrast_loss(x_dis, adj_label, tau=1): - """Graph MLP contrastive loss. - - Parameters - ---------- - x_dis : torch.Tensor - Distance matrix. - adj_label : torch.Tensor - Adjacency matrix. - tau : float, optional - Temperature parameter (default: 1). - - Returns - ------- - torch.Tensor - Contrastive loss. - """ - x_dis = torch.exp(tau * x_dis) - x_dis_sum = torch.sum(x_dis, 1) - x_dis_sum_pos = torch.sum(x_dis * adj_label, 1) - loss = -torch.log(x_dis_sum_pos * (x_dis_sum ** (-1)) + 1e-8).mean() - return loss diff --git a/topobenchmarkx/loss/loss.py b/topobenchmarkx/loss/loss.py index 3e8c33d8..5198bde4 100644 --- a/topobenchmarkx/loss/loss.py +++ b/topobenchmarkx/loss/loss.py @@ -4,6 +4,7 @@ import torch_geometric from topobenchmarkx.loss.base import AbstractLoss +from topobenchmarkx.loss.custom_losses import DatasetLoss class TBXLoss(AbstractLoss): @@ -11,27 +12,20 @@ class TBXLoss(AbstractLoss): Parameters ---------- - task : str - Task type, either "classification" or "regression". - loss_type : str, optional - Loss type, either "cross_entropy", "mse", or "mae" (default: None). + dataset_loss : dict + Dictionary containing the dataset loss information. + model_loss : AbstractLoss, optional + Custom model loss to be used. 
""" - def __init__(self, task, loss_type=None): + def __init__(self, dataset_loss, model_loss=None): super().__init__() - self.task = task - if task == "classification" and loss_type == "cross_entropy": - self.criterion = torch.nn.CrossEntropyLoss() - - elif task == "regression" and loss_type == "mse": - self.criterion = torch.nn.MSELoss() - - elif task == "regression" and loss_type == "mae": - self.criterion = torch.nn.L1Loss() - - else: - raise Exception("Loss is not defined") - self.loss_type = loss_type + self.losses = [] + # Dataset loss + self.losses.append(DatasetLoss(dataset_loss)) + # Model loss + if model_loss is not None: + self.losses.append(model_loss) def __repr__(self) -> str: return f"{self.__class__.__name__}(task={self.task}, loss_type={self.loss_type})" @@ -51,12 +45,8 @@ def forward(self, model_out: dict, batch: torch_geometric.data.Data): dict Dictionary containing the model output with the loss. """ - logits = model_out["logits"] - target = model_out["labels"] - - if self.task == "regression": - target = target.unsqueeze(1) + losses = [loss(model_out, batch) for loss in self.losses] - model_out["loss"] = self.criterion(logits, target) + model_out["loss"] = torch.stack(losses).sum() return model_out diff --git a/topobenchmarkx/nn/backbones/graph/graph_mlp.py b/topobenchmarkx/nn/backbones/graph/graph_mlp.py index 0062c535..a8edff79 100644 --- a/topobenchmarkx/nn/backbones/graph/graph_mlp.py +++ b/topobenchmarkx/nn/backbones/graph/graph_mlp.py @@ -18,9 +18,13 @@ class GraphMLP(nn.Module): To compute order-th power of adj matrix (default: 1). dropout : float, optional Dropout rate (default: 0.0). + **kwargs + Additional arguments. """ - def __init__(self, in_channels, hidden_channels, order=1, dropout=0.0): + def __init__( + self, in_channels, hidden_channels, order=1, dropout=0.0, **kwargs + ): super().__init__() self.nhid = hidden_channels self.order = order diff --git a/topobenchmarkx/nn/wrappers/graph/graph_mlp_wrapper.py b/topobenchmarkx/nn/wrappers/graph/graph_mlp_wrapper.py index d28bae0d..8896da75 100644 --- a/topobenchmarkx/nn/wrappers/graph/graph_mlp_wrapper.py +++ b/topobenchmarkx/nn/wrappers/graph/graph_mlp_wrapper.py @@ -1,8 +1,5 @@ """Wrapper for the GNN models.""" -import torch -import torch_geometric - from topobenchmarkx.nn.wrappers.base import AbstractWrapper @@ -13,28 +10,6 @@ class GraphMLPWrapper(AbstractWrapper): the embeddings of the cells of rank 0. """ - def get_power_adj(self, edge_index, order=1): - r"""Get the power of the adjacency matrix. - - Parameters - ---------- - edge_index : torch.Tensor - Edge index tensor. - order : int, optional - Order of the adjacency matrix (default: 1). - - Returns - ------- - torch.Tensor - Power of the adjacency matrix. - """ - edge_index, _ = torch_geometric.utils.remove_self_loops(edge_index) - adj = torch_geometric.utils.to_dense_adj(edge_index) - adj_power = adj.clone() - for _ in range(order - 1): - adj_power = torch.matmul(adj_power, adj) - return adj_power - def forward(self, batch): r"""Forward pass for the GNN wrapper. 
@@ -53,8 +28,5 @@ def forward(self, batch): model_out = {"labels": batch.y, "batch_0": batch.batch_0} model_out["x_0"] = x_0 model_out["x_dis"] = x_dis - model_out["adj_label"] = self.get_power_adj( - batch.edge_index, self.backbone.order - ) return model_out From ee8c936d0d480e1acb555864cf82add79214c151 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Guillermo=20Bern=C3=A1rdez?= Date: Thu, 21 Nov 2024 11:33:27 -0800 Subject: [PATCH 4/9] Enabling custom losses in all models' modules --- configs/loss/default.yaml | 5 ++++- configs/model/graph/graph_mlp.yaml | 4 ++-- .../{custom_losses => dataset}/DatasetLoss.py | 0 topobenchmarkx/loss/dataset/__init__.py | 7 +++++++ topobenchmarkx/loss/loss.py | 15 ++++++++------- .../loss/{custom_losses => model}/GraphMLPLoss.py | 10 +++++----- .../loss/{custom_losses => model}/__init__.py | 2 -- topobenchmarkx/model/model.py | 4 ++-- 8 files changed, 28 insertions(+), 19 deletions(-) rename topobenchmarkx/loss/{custom_losses => dataset}/DatasetLoss.py (100%) create mode 100644 topobenchmarkx/loss/dataset/__init__.py rename topobenchmarkx/loss/{custom_losses => model}/GraphMLPLoss.py (91%) rename topobenchmarkx/loss/{custom_losses => model}/__init__.py (67%) diff --git a/configs/loss/default.yaml b/configs/loss/default.yaml index af17d268..ca1742d7 100644 --- a/configs/loss/default.yaml +++ b/configs/loss/default.yaml @@ -2,4 +2,7 @@ _target_: topobenchmarkx.loss.TBXLoss dataset_loss: task: ${dataset.parameters.task} loss_type: ${dataset.parameters.loss_type} -model_loss: ${oc.select:model.backbone.loss,null} \ No newline at end of file +modules_losses: # Collect model losses + feature_encoder: ${oc.select:model.feature_encoder.loss,null} + backbone: ${oc.select:model.backbone.loss,null} + readout: ${oc.select:model.readout.loss,null} \ No newline at end of file diff --git a/configs/model/graph/graph_mlp.yaml b/configs/model/graph/graph_mlp.yaml index 36dab62f..34fe5072 100755 --- a/configs/model/graph/graph_mlp.yaml +++ b/configs/model/graph/graph_mlp.yaml @@ -17,10 +17,10 @@ backbone: order: 2 dropout: 0.0 loss: - _target_: topobenchmarkx.loss.custom_losses.GraphMLPLoss + _target_: topobenchmarkx.loss.model.GraphMLPLoss r_adj_power: 2 - alpha: 1. tau: 1. + loss_weight: 0.5 backbone_wrapper: _target_: topobenchmarkx.nn.wrappers.GraphMLPWrapper diff --git a/topobenchmarkx/loss/custom_losses/DatasetLoss.py b/topobenchmarkx/loss/dataset/DatasetLoss.py similarity index 100% rename from topobenchmarkx/loss/custom_losses/DatasetLoss.py rename to topobenchmarkx/loss/dataset/DatasetLoss.py diff --git a/topobenchmarkx/loss/dataset/__init__.py b/topobenchmarkx/loss/dataset/__init__.py new file mode 100644 index 00000000..3cfc3b83 --- /dev/null +++ b/topobenchmarkx/loss/dataset/__init__.py @@ -0,0 +1,7 @@ +"""Init file for custom loss module.""" + +from .DatasetLoss import DatasetLoss + +__all__ = [ + "DatasetLoss", +] diff --git a/topobenchmarkx/loss/loss.py b/topobenchmarkx/loss/loss.py index 5198bde4..482d5807 100644 --- a/topobenchmarkx/loss/loss.py +++ b/topobenchmarkx/loss/loss.py @@ -4,7 +4,7 @@ import torch_geometric from topobenchmarkx.loss.base import AbstractLoss -from topobenchmarkx.loss.custom_losses import DatasetLoss +from topobenchmarkx.loss.dataset import DatasetLoss class TBXLoss(AbstractLoss): @@ -14,18 +14,19 @@ class TBXLoss(AbstractLoss): ---------- dataset_loss : dict Dictionary containing the dataset loss information. - model_loss : AbstractLoss, optional - Custom model loss to be used. 
+ modules_losses : AbstractLoss, optional + Custom modules' losses to be used. """ - def __init__(self, dataset_loss, model_loss=None): + def __init__(self, dataset_loss, modules_losses=()): super().__init__() self.losses = [] # Dataset loss self.losses.append(DatasetLoss(dataset_loss)) - # Model loss - if model_loss is not None: - self.losses.append(model_loss) + # Model losses + self.losses.extend( + [loss for loss in modules_losses.values() if loss is not None] + ) def __repr__(self) -> str: return f"{self.__class__.__name__}(task={self.task}, loss_type={self.loss_type})" diff --git a/topobenchmarkx/loss/custom_losses/GraphMLPLoss.py b/topobenchmarkx/loss/model/GraphMLPLoss.py similarity index 91% rename from topobenchmarkx/loss/custom_losses/GraphMLPLoss.py rename to topobenchmarkx/loss/model/GraphMLPLoss.py index ed0646cd..c1949894 100644 --- a/topobenchmarkx/loss/custom_losses/GraphMLPLoss.py +++ b/topobenchmarkx/loss/model/GraphMLPLoss.py @@ -13,17 +13,17 @@ class GraphMLPLoss(AbstractLoss): ---------- r_adj_power : int, optional Power of the adjacency matrix (default: 2). - alpha : float, optional - Alpha parameter (default: 1). tau : float, optional Temperature parameter (default: 1). + loss_weight : float, optional + Loss weight (default: 0.5). """ - def __init__(self, r_adj_power=2, alpha=0.5, tau=1.0): + def __init__(self, r_adj_power=2, tau=1.0, loss_weight=0.5): super().__init__() self.r_adj_power = r_adj_power - self.alpha = alpha self.tau = tau + self.loss_weight = loss_weight def __repr__(self) -> str: return f"{self.__class__.__name__}(r_adj_power={self.r_adj_power}, alpha={self.alpha}, tau={self.tau})" @@ -90,7 +90,7 @@ def forward( if x_dis is None: # Validation and test return torch.tensor(0.0) adj_label = self.get_power_adj(batch.edge_index) - graph_mlp_loss = self.alpha * self.graph_mlp_contrast_loss( + graph_mlp_loss = self.loss_weight * self.graph_mlp_contrast_loss( x_dis, adj_label ) return graph_mlp_loss diff --git a/topobenchmarkx/loss/custom_losses/__init__.py b/topobenchmarkx/loss/model/__init__.py similarity index 67% rename from topobenchmarkx/loss/custom_losses/__init__.py rename to topobenchmarkx/loss/model/__init__.py index 8bcdc7b2..4943a8d9 100644 --- a/topobenchmarkx/loss/custom_losses/__init__.py +++ b/topobenchmarkx/loss/model/__init__.py @@ -1,9 +1,7 @@ """Init file for custom loss module.""" -from .DatasetLoss import DatasetLoss from .GraphMLPLoss import GraphMLPLoss __all__ = [ "GraphMLPLoss", - "DatasetLoss", ] diff --git a/topobenchmarkx/model/model.py b/topobenchmarkx/model/model.py index bf89129e..189c794a 100755 --- a/topobenchmarkx/model/model.py +++ b/topobenchmarkx/model/model.py @@ -106,10 +106,10 @@ def model_step(self, batch: Data) -> dict: Dictionary containing the model output and the loss. 
""" # Feature Encoder - batch = self.feature_encoder(batch) + model_out = self.feature_encoder(batch) # Domain model - model_out = self.forward(batch) + model_out = self.forward(model_out) # Readout if self.readout is not None: From 9f349373009f597f3c0a2eb7cc42dc958db4f45e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Guillermo=20Bern=C3=A1rdez?= Date: Thu, 21 Nov 2024 13:34:13 -0800 Subject: [PATCH 5/9] Add GraphML tests --- configs/run.yaml | 2 +- test/data/load/test_datasetloaders.py | 2 +- test/nn/backbones/graph/__init__.py | 1 + test/nn/backbones/graph/test_graphmlp.py | 17 +++++++++++++++++ 4 files changed, 20 insertions(+), 2 deletions(-) create mode 100644 test/nn/backbones/graph/__init__.py create mode 100644 test/nn/backbones/graph/test_graphmlp.py diff --git a/configs/run.yaml b/configs/run.yaml index d7192338..8ec582f4 100755 --- a/configs/run.yaml +++ b/configs/run.yaml @@ -5,7 +5,7 @@ defaults: - _self_ - dataset: graph/cocitation_cora - - model: graph/graph_mlp + - model: cell/topotune - transforms: ${get_default_transform:${dataset},${model}} #no_transform - optimizer: default - loss: default diff --git a/test/data/load/test_datasetloaders.py b/test/data/load/test_datasetloaders.py index 0c387d1b..82790a94 100644 --- a/test/data/load/test_datasetloaders.py +++ b/test/data/load/test_datasetloaders.py @@ -75,7 +75,7 @@ def _load_dataset(self, data_domain: str, config_file: str) -> Tuple[Any, Dict]: print('Current config file: ', config_file) parameters = hydra.compose( config_name="run.yaml", - overrides=[f"dataset={data_domain}/{config_file}"], + overrides=[f"dataset={data_domain}/{config_file}", f"model=graph/gat"], return_hydra_config=True ) diff --git a/test/nn/backbones/graph/__init__.py b/test/nn/backbones/graph/__init__.py new file mode 100644 index 00000000..2bbda54e --- /dev/null +++ b/test/nn/backbones/graph/__init__.py @@ -0,0 +1 @@ +"""Init file for custom graph models.""" \ No newline at end of file diff --git a/test/nn/backbones/graph/test_graphmlp.py b/test/nn/backbones/graph/test_graphmlp.py new file mode 100644 index 00000000..d5d4edf4 --- /dev/null +++ b/test/nn/backbones/graph/test_graphmlp.py @@ -0,0 +1,17 @@ +"""Unit tests for GraphMLP.""" + +from topobenchmarkx.nn.backbones.graph.graph_mlp import GraphMLP + +def testGraphMLP(random_graph_input): + """ Unit test for GraphMLP. + + Parameters + ---------- + random_graph_input : Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor]] + A tuple of input tensors for testing EDGNN. 
+ """ + x, x_1, x_2, edges_1, edges_2 = random_graph_input + model = GraphMLP(x.shape[1], x.shape[1]) + out = model(x) + assert out[0].shape == x.shape + assert list(out[1].shape) == [8,8] From 564d5e25ba6ed898807edbeb2ec7b3b6ac5fc0a5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Guillermo=20Bern=C3=A1rdez?= Date: Thu, 21 Nov 2024 14:26:05 -0800 Subject: [PATCH 6/9] Optional residual connections in Wrappers --- topobenchmarkx/nn/wrappers/base.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/topobenchmarkx/nn/wrappers/base.py b/topobenchmarkx/nn/wrappers/base.py index 39f813d4..a9bf1726 100755 --- a/topobenchmarkx/nn/wrappers/base.py +++ b/topobenchmarkx/nn/wrappers/base.py @@ -24,6 +24,7 @@ def __init__(self, backbone, **kwargs): self.backbone = backbone out_channels = kwargs["out_channels"] self.dimensions = range(kwargs["num_cell_dimensions"]) + self.residual_connections = kwargs.get("residual_connections", True) for i in self.dimensions: setattr( @@ -33,7 +34,7 @@ def __init__(self, backbone, **kwargs): ) def __repr__(self): - return f"{self.__class__.__name__}(backbone={self.backbone}, out_channels={self.backbone.out_channels}, dimensions={self.dimensions})" + return f"{self.__class__.__name__}(backbone={self.backbone}, out_channels={self.backbone.out_channels}, dimensions={self.dimensions}, residual_connections={self.residual_connections})" def __call__(self, batch): r"""Forward pass for the model. @@ -51,7 +52,11 @@ def __call__(self, batch): Dictionary containing the model output. """ model_out = self.forward(batch) - model_out = self.residual_connection(model_out=model_out, batch=batch) + model_out = ( + self.residual_connection(model_out=model_out, batch=batch) + if self.residual_connections + else model_out + ) return model_out def residual_connection(self, model_out, batch): From b4dce8263ea5f88354b1dcdfecfef5127cf5de09 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Guillermo=20Bern=C3=A1rdez?= Date: Thu, 21 Nov 2024 16:56:04 -0800 Subject: [PATCH 7/9] Add GraphMLP tests --- test/nn/backbones/graph/test_graphmlp.py | 21 +++++++++++++++---- .../nn/backbones/graph/graph_mlp.py | 4 ++-- 2 files changed, 19 insertions(+), 6 deletions(-) diff --git a/test/nn/backbones/graph/test_graphmlp.py b/test/nn/backbones/graph/test_graphmlp.py index d5d4edf4..e7682f10 100644 --- a/test/nn/backbones/graph/test_graphmlp.py +++ b/test/nn/backbones/graph/test_graphmlp.py @@ -1,6 +1,10 @@ """Unit tests for GraphMLP.""" -from topobenchmarkx.nn.backbones.graph.graph_mlp import GraphMLP +import torch +import torch_geometric +from topobenchmarkx.nn.backbones.graph import GraphMLP +from topobenchmarkx.nn.wrappers.graph import GraphMLPWrapper +from topobenchmarkx.loss.model import GraphMLPLoss def testGraphMLP(random_graph_input): """ Unit test for GraphMLP. @@ -11,7 +15,16 @@ def testGraphMLP(random_graph_input): A tuple of input tensors for testing EDGNN. 
""" x, x_1, x_2, edges_1, edges_2 = random_graph_input + batch = torch_geometric.data.Data(x_0=x, y=x, edge_index=edges_1, batch_0=torch.zeros(x.shape[0], dtype=torch.long)) model = GraphMLP(x.shape[1], x.shape[1]) - out = model(x) - assert out[0].shape == x.shape - assert list(out[1].shape) == [8,8] + wrapper = GraphMLPWrapper(model, **{"out_channels": x.shape[1], "num_cell_dimensions": 1}) + loss_fn = GraphMLPLoss() + + model_out = wrapper(batch) + assert model_out["x_0"].shape == x.shape + assert list(model_out["x_dis"].shape) == [8,8] + + loss = loss_fn(model_out, batch) + assert loss.item() >= 0 + + diff --git a/topobenchmarkx/nn/backbones/graph/graph_mlp.py b/topobenchmarkx/nn/backbones/graph/graph_mlp.py index a8edff79..8065a35f 100644 --- a/topobenchmarkx/nn/backbones/graph/graph_mlp.py +++ b/topobenchmarkx/nn/backbones/graph/graph_mlp.py @@ -26,9 +26,9 @@ def __init__( self, in_channels, hidden_channels, order=1, dropout=0.0, **kwargs ): super().__init__() - self.nhid = hidden_channels + self.out_channels = hidden_channels self.order = order - self.mlp = Mlp(in_channels, self.nhid, dropout) + self.mlp = Mlp(in_channels, self.out_channels, dropout) def forward(self, x): """Forward pass. From 2c27813eada12614f00c48ea75d7f7f56116a5b7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Guillermo=20Bern=C3=A1rdez?= Date: Thu, 21 Nov 2024 19:32:49 -0800 Subject: [PATCH 8/9] Fix loss config in tutorials --- topobenchmarkx/loss/loss.py | 2 +- tutorials/tutorial_dataset.ipynb | 12 +++++++++--- tutorials/tutorial_lifting.ipynb | 12 +++++++++--- tutorials/tutorial_model.ipynb | 12 +++++++++--- 4 files changed, 28 insertions(+), 10 deletions(-) diff --git a/topobenchmarkx/loss/loss.py b/topobenchmarkx/loss/loss.py index 482d5807..95c68b23 100644 --- a/topobenchmarkx/loss/loss.py +++ b/topobenchmarkx/loss/loss.py @@ -18,7 +18,7 @@ class TBXLoss(AbstractLoss): Custom modules' losses to be used. 
""" - def __init__(self, dataset_loss, modules_losses=()): + def __init__(self, dataset_loss, modules_losses={}): # noqa: B006 super().__init__() self.losses = [] # Dataset loss diff --git a/tutorials/tutorial_dataset.ipynb b/tutorials/tutorial_dataset.ipynb index 6ab70534..fa2be8d2 100644 --- a/tutorials/tutorial_dataset.ipynb +++ b/tutorials/tutorial_dataset.ipynb @@ -80,7 +80,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -116,7 +116,13 @@ " \"pooling_type\": \"sum\",\n", "}\n", "\n", - "loss_config = {\"task\": \"classification\", \"loss_type\": \"cross_entropy\"}\n", + "loss_config = {\n", + " \"dataset_loss\": \n", + " {\n", + " \"task\": \"classification\", \n", + " \"loss_type\": \"cross_entropy\"\n", + " }\n", + "}\n", "\n", "evaluator_config = {\"task\": \"classification\",\n", " \"num_classes\": out_channels,\n", @@ -425,7 +431,7 @@ ], "metadata": { "kernelspec": { - "display_name": "tb", + "display_name": "topox", "language": "python", "name": "python3" }, diff --git a/tutorials/tutorial_lifting.ipynb b/tutorials/tutorial_lifting.ipynb index 1fa6bc4a..9e3fa0ad 100644 --- a/tutorials/tutorial_lifting.ipynb +++ b/tutorials/tutorial_lifting.ipynb @@ -93,7 +93,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -137,7 +137,13 @@ " \"pooling_type\": \"sum\",\n", "}\n", "\n", - "loss_config = {\"task\": \"classification\", \"loss_type\": \"cross_entropy\"}\n", + "loss_config = {\n", + " \"dataset_loss\": \n", + " {\n", + " \"task\": \"classification\", \n", + " \"loss_type\": \"cross_entropy\"\n", + " }\n", + "}\n", "\n", "evaluator_config = {\"task\": \"classification\",\n", " \"num_classes\": out_channels,\n", @@ -514,7 +520,7 @@ ], "metadata": { "kernelspec": { - "display_name": "tb", + "display_name": "topox", "language": "python", "name": "python3" }, diff --git a/tutorials/tutorial_model.ipynb b/tutorials/tutorial_model.ipynb index 97e62498..a7289f6b 100644 --- a/tutorials/tutorial_model.ipynb +++ b/tutorials/tutorial_model.ipynb @@ -82,7 +82,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -119,7 +119,13 @@ " \"pooling_type\": \"sum\",\n", "}\n", "\n", - "loss_config = {\"task\": \"classification\", \"loss_type\": \"cross_entropy\"}\n", + "loss_config = {\n", + " \"dataset_loss\": \n", + " {\n", + " \"task\": \"classification\", \n", + " \"loss_type\": \"cross_entropy\"\n", + " }\n", + "}\n", "\n", "evaluator_config = {\"task\": \"classification\",\n", " \"num_classes\": out_channels,\n", @@ -470,7 +476,7 @@ ], "metadata": { "kernelspec": { - "display_name": "tb", + "display_name": "topox", "language": "python", "name": "python3" }, From 8492994fea3395231f293c06363c0c5dcaa2e85f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Guillermo=20Bern=C3=A1rdez?= Date: Thu, 21 Nov 2024 20:15:21 -0800 Subject: [PATCH 9/9] Improve test coverage --- test/loss/test_dataset_loss.py | 34 +++++++++++++++++++++++ test/nn/backbones/graph/test_graphmlp.py | 7 +++++ topobenchmarkx/loss/model/GraphMLPLoss.py | 2 +- 3 files changed, 42 insertions(+), 1 deletion(-) create mode 100644 test/loss/test_dataset_loss.py diff --git a/test/loss/test_dataset_loss.py b/test/loss/test_dataset_loss.py new file mode 100644 index 00000000..2097eba6 --- /dev/null +++ b/test/loss/test_dataset_loss.py @@ -0,0 +1,34 @@ +""" Test the TBXEvaluator class.""" +import pytest +import torch +import 
From 8492994fea3395231f293c06363c0c5dcaa2e85f Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Guillermo=20Bern=C3=A1rdez?=
Date: Thu, 21 Nov 2024 20:15:21 -0800
Subject: [PATCH 9/9] Improve test coverage

---
 test/loss/test_dataset_loss.py             | 34 +++++++++++++++++++++++
 test/nn/backbones/graph/test_graphmlp.py   |  7 +++++
 topobenchmarkx/loss/model/GraphMLPLoss.py  |  2 +-
 3 files changed, 42 insertions(+), 1 deletion(-)
 create mode 100644 test/loss/test_dataset_loss.py

diff --git a/test/loss/test_dataset_loss.py b/test/loss/test_dataset_loss.py
new file mode 100644
index 00000000..2097eba6
--- /dev/null
+++ b/test/loss/test_dataset_loss.py
@@ -0,0 +1,34 @@
+""" Test the DatasetLoss class."""
+import pytest
+import torch
+import torch_geometric
+
+from topobenchmarkx.loss.dataset import DatasetLoss
+
+class TestDatasetLoss:
+    """ Test the DatasetLoss class."""
+
+    def setup_method(self):
+        """ Setup the test."""
+        dataset_loss = {"task": "classification", "loss_type": "cross_entropy"}
+        self.dataset1 = DatasetLoss(dataset_loss)
+        dataset_loss = {"task": "regression", "loss_type": "mse"}
+        self.dataset2 = DatasetLoss(dataset_loss)
+        dataset_loss = {"task": "regression", "loss_type": "mae"}
+        self.dataset3 = DatasetLoss(dataset_loss)
+        dataset_loss = {"task": "wrong", "loss_type": "wrong"}
+        with pytest.raises(Exception):
+            DatasetLoss(dataset_loss)
+        repr = self.dataset1.__repr__()
+        assert repr == "DatasetLoss(task=classification, loss_type=cross_entropy)"
+
+    def test_forward(self):
+        """ Test the forward method."""
+        batch = torch_geometric.data.Data()
+        model_out = {"logits": torch.tensor([0.1, 0.2, 0.3]), "labels": torch.tensor([0.1, 0.2, 0.3])}
+        out = self.dataset1.forward(model_out, batch)
+        assert out.item() >= 0
+        model_out = {"logits": torch.tensor([0.1, 0.2, 0.3]), "labels": torch.tensor([0.1, 0.2, 0.3])}
+        out = self.dataset3.forward(model_out, batch)
+        assert out.item() >= 0
+        
\ No newline at end of file
diff --git a/test/nn/backbones/graph/test_graphmlp.py b/test/nn/backbones/graph/test_graphmlp.py
index e7682f10..5810414d 100644
--- a/test/nn/backbones/graph/test_graphmlp.py
+++ b/test/nn/backbones/graph/test_graphmlp.py
@@ -20,6 +20,9 @@ def testGraphMLP(random_graph_input):
     wrapper = GraphMLPWrapper(model, **{"out_channels": x.shape[1], "num_cell_dimensions": 1})
     loss_fn = GraphMLPLoss()
 
+    _ = wrapper.__repr__()
+    _ = loss_fn.__repr__()
+
     model_out = wrapper(batch)
     assert model_out["x_0"].shape == x.shape
     assert list(model_out["x_dis"].shape) == [8,8]
@@ -27,4 +30,8 @@ def testGraphMLP(random_graph_input):
     loss = loss_fn(model_out, batch)
     assert loss.item() >= 0
 
+    model_out["x_dis"] = None
+    loss = loss_fn(model_out, batch)
+    assert loss == torch.tensor(0.0)
+
 
diff --git a/topobenchmarkx/loss/model/GraphMLPLoss.py b/topobenchmarkx/loss/model/GraphMLPLoss.py
index c1949894..3ee35575 100644
--- a/topobenchmarkx/loss/model/GraphMLPLoss.py
+++ b/topobenchmarkx/loss/model/GraphMLPLoss.py
@@ -26,7 +26,7 @@ def __init__(self, r_adj_power=2, tau=1.0, loss_weight=0.5):
         self.loss_weight = loss_weight
 
     def __repr__(self) -> str:
-        return f"{self.__class__.__name__}(r_adj_power={self.r_adj_power}, alpha={self.alpha}, tau={self.tau})"
+        return f"{self.__class__.__name__}(r_adj_power={self.r_adj_power}, tau={self.tau}, loss_weight={self.loss_weight})"
 
     def get_power_adj(self, edge_index):
         r"""Get the power of the adjacency matrix.
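
Taken as a whole, the series wires Graph-MLP's auxiliary objective end to end: the backbone emits a feature-distance matrix during training, the wrapper forwards it, and `GraphMLPLoss` turns it into a neighbourhood-contrastive term. A toy run tying these pieces together, mirroring what `test_graphmlp.py` exercises; the ring graph, seed, and sizes are illustrative assumptions:

```python
# Toy end-to-end pass: backbone feature-distance matrix -> contrastive loss.
import torch
import torch_geometric

from topobenchmarkx.loss.model import GraphMLPLoss
from topobenchmarkx.nn.backbones.graph import GraphMLP

torch.manual_seed(0)
num_nodes, dim = 6, 4
x = torch.randn(num_nodes, dim)
src = torch.arange(num_nodes)
edge_index = torch.stack([src, (src + 1) % num_nodes])  # directed ring

model = GraphMLP(in_channels=dim, hidden_channels=dim, order=2)
model.train()  # x_dis is only computed in training mode
x_0, x_dis = model(x)

loss_fn = GraphMLPLoss(r_adj_power=2, tau=1.0, loss_weight=0.5)
batch = torch_geometric.data.Data(x_0=x_0, edge_index=edge_index)
aux = loss_fn({"x_dis": x_dis}, batch)  # scalar auxiliary loss

model.eval()
_, x_dis_eval = model(x)  # eval mode returns x_dis = None ...
assert loss_fn({"x_dis": x_dis_eval}, batch) == 0.0  # ... and the loss is 0
```

Nodes reachable by a walk of length `r_adj_power` in the powered adjacency matrix act as positives in the contrastive term, so the auxiliary loss is a plain scalar that `TBXLoss` simply adds to the dataset loss.
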