Skip to content

Commit

Permalink
Merge pull request #112 from geometric-intelligence/loss
Browse files Browse the repository at this point in the history
Loss
  • Loading branch information
gbg141 authored Nov 22, 2024
2 parents a21a1f7 + 8492994 commit 6b5c349
Show file tree
Hide file tree
Showing 23 changed files with 516 additions and 44 deletions.
9 changes: 7 additions & 2 deletions configs/loss/default.yaml
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
_target_: topobenchmarkx.loss.TBXLoss
task: ${dataset.parameters.task}
loss_type: ${dataset.parameters.loss_type}
dataset_loss:
task: ${dataset.parameters.task}
loss_type: ${dataset.parameters.loss_type}
modules_losses: # Collect model losses
feature_encoder: ${oc.select:model.feature_encoder.loss,null}
backbone: ${oc.select:model.backbone.loss,null}
readout: ${oc.select:model.readout.loss,null}
44 changes: 44 additions & 0 deletions configs/model/graph/graph_mlp.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
_target_: topobenchmarkx.model.TBXModel

model_name: GraphMLP
model_domain: graph

feature_encoder:
_target_: topobenchmarkx.nn.encoders.${model.feature_encoder.encoder_name}
encoder_name: AllCellFeatureEncoder
in_channels: ${infer_in_channels:${dataset},${oc.select:transforms,null}}
out_channels: 32
proj_dropout: 0.0

backbone:
_target_: topobenchmarkx.nn.backbones.GraphMLP
in_channels: ${model.feature_encoder.out_channels}
hidden_channels: ${model.feature_encoder.out_channels}
order: 2
dropout: 0.0
loss:
_target_: topobenchmarkx.loss.model.GraphMLPLoss
r_adj_power: 2
tau: 1.
loss_weight: 0.5

backbone_wrapper:
_target_: topobenchmarkx.nn.wrappers.GraphMLPWrapper
_partial_: true
wrapper_name: GraphMLPWrapper
out_channels: ${model.feature_encoder.out_channels}
num_cell_dimensions: ${infere_num_cell_dimensions:${oc.select:model.feature_encoder.selected_dimensions,null},${model.feature_encoder.in_channels}}

readout:
_target_: topobenchmarkx.nn.readouts.${model.readout.readout_name}
readout_name: NoReadOut # Use <NoReadOut> in case readout is not needed Options: PropagateSignalDown
num_cell_dimensions: ${infere_num_cell_dimensions:${oc.select:model.feature_encoder.selected_dimensions,null},${model.feature_encoder.in_channels}} # The highest order of cell dimensions to consider
hidden_dim: ${model.feature_encoder.out_channels}
out_channels: ${dataset.parameters.num_classes}
task_level: ${dataset.parameters.task_level}
pooling_type: sum



# compile model for faster training with pytorch 2.0
compile: false
4 changes: 2 additions & 2 deletions configs/run.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,14 @@
defaults:
- _self_
- dataset: graph/cocitation_cora
- model: hypergraph/unignn2
- model: cell/topotune
- transforms: ${get_default_transform:${dataset},${model}} #no_transform
- optimizer: default
- loss: default
- evaluator: default
- callbacks: default
- logger: wandb # set logger here or use command line (e.g. `python train.py logger=tensorboard`)
- trainer: default
- trainer: cpu
- paths: default
- extras: default
- hydra: default
Expand Down
2 changes: 1 addition & 1 deletion test/data/load/test_datasetloaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def _load_dataset(self, data_domain: str, config_file: str) -> Tuple[Any, Dict]:
print('Current config file: ', config_file)
parameters = hydra.compose(
config_name="run.yaml",
overrides=[f"dataset={data_domain}/{config_file}"],
overrides=[f"dataset={data_domain}/{config_file}", f"model=graph/gat"],
return_hydra_config=True

)
Expand Down
34 changes: 34 additions & 0 deletions test/loss/test_dataset_loss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
""" Test the TBXEvaluator class."""
import pytest
import torch
import torch_geometric

from topobenchmarkx.loss.dataset import DatasetLoss

class TestDatasetLoss:
""" Test the TBXEvaluator class."""

def setup_method(self):
""" Setup the test."""
dataset_loss = {"task": "classification", "loss_type": "cross_entropy"}
self.dataset1 = DatasetLoss(dataset_loss)
dataset_loss = {"task": "regression", "loss_type": "mse"}
self.dataset2 = DatasetLoss(dataset_loss)
dataset_loss = {"task": "regression", "loss_type": "mae"}
self.dataset3 = DatasetLoss(dataset_loss)
dataset_loss = {"task": "wrong", "loss_type": "wrong"}
with pytest.raises(Exception):
DatasetLoss(dataset_loss)
repr = self.dataset1.__repr__()
assert repr == "DatasetLoss(task=classification, loss_type=cross_entropy)"

def test_forward(self):
""" Test the forward method."""
batch = torch_geometric.data.Data()
model_out = {"logits": torch.tensor([0.1, 0.2, 0.3]), "labels": torch.tensor([0.1, 0.2, 0.3])}
out = self.dataset1.forward(model_out, batch)
assert out.item() >= 0
model_out = {"logits": torch.tensor([0.1, 0.2, 0.3]), "labels": torch.tensor([0.1, 0.2, 0.3])}
out = self.dataset3.forward(model_out, batch)
assert out.item() >= 0

1 change: 1 addition & 0 deletions test/nn/backbones/graph/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Init file for custom graph models."""
37 changes: 37 additions & 0 deletions test/nn/backbones/graph/test_graphmlp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
"""Unit tests for GraphMLP."""

import torch
import torch_geometric
from topobenchmarkx.nn.backbones.graph import GraphMLP
from topobenchmarkx.nn.wrappers.graph import GraphMLPWrapper
from topobenchmarkx.loss.model import GraphMLPLoss

def testGraphMLP(random_graph_input):
""" Unit test for GraphMLP.
Parameters
----------
random_graph_input : Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]
A tuple of input tensors for testing EDGNN.
"""
x, x_1, x_2, edges_1, edges_2 = random_graph_input
batch = torch_geometric.data.Data(x_0=x, y=x, edge_index=edges_1, batch_0=torch.zeros(x.shape[0], dtype=torch.long))
model = GraphMLP(x.shape[1], x.shape[1])
wrapper = GraphMLPWrapper(model, **{"out_channels": x.shape[1], "num_cell_dimensions": 1})
loss_fn = GraphMLPLoss()

_ = wrapper.__repr__()
_ = loss_fn.__repr__()

model_out = wrapper(batch)
assert model_out["x_0"].shape == x.shape
assert list(model_out["x_dis"].shape) == [8,8]

loss = loss_fn(model_out, batch)
assert loss.item() >= 0

model_out["x_dis"] = None
loss = loss_fn(model_out, batch)
assert loss == torch.tensor(0.0)


58 changes: 58 additions & 0 deletions topobenchmarkx/loss/dataset/DatasetLoss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
"""Loss module for the topobenchmarkx package."""

import torch
import torch_geometric

from topobenchmarkx.loss.base import AbstractLoss


class DatasetLoss(AbstractLoss):
r"""Defines the default model loss for the given task.
Parameters
----------
dataset_loss : dict
Dictionary containing the dataset loss information.
"""

def __init__(self, dataset_loss):
super().__init__()
self.task = dataset_loss["task"]
self.loss_type = dataset_loss["loss_type"]
# Dataset loss
if self.task == "classification" and self.loss_type == "cross_entropy":
self.criterion = torch.nn.CrossEntropyLoss()
elif self.task == "regression" and self.loss_type == "mse":
self.criterion = torch.nn.MSELoss()
elif self.task == "regression" and self.loss_type == "mae":
self.criterion = torch.nn.L1Loss()
else:
raise Exception("Loss is not defined")

def __repr__(self) -> str:
return f"{self.__class__.__name__}(task={self.task}, loss_type={self.loss_type})"

def forward(self, model_out: dict, batch: torch_geometric.data.Data):
r"""Forward pass of the loss function.
Parameters
----------
model_out : dict
Dictionary containing the model output.
batch : torch_geometric.data.Data
Batch object containing the batched domain data.
Returns
-------
dict
Dictionary containing the model output with the loss.
"""
logits = model_out["logits"]
target = model_out["labels"]

if self.task == "regression":
target = target.unsqueeze(1)

dataset_loss = self.criterion(logits, target)

return dataset_loss
7 changes: 7 additions & 0 deletions topobenchmarkx/loss/dataset/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
"""Init file for custom loss module."""

from .DatasetLoss import DatasetLoss

__all__ = [
"DatasetLoss",
]
39 changes: 15 additions & 24 deletions topobenchmarkx/loss/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,34 +4,29 @@
import torch_geometric

from topobenchmarkx.loss.base import AbstractLoss
from topobenchmarkx.loss.dataset import DatasetLoss


class TBXLoss(AbstractLoss):
r"""Defines the default model loss for the given task.
Parameters
----------
task : str
Task type, either "classification" or "regression".
loss_type : str, optional
Loss type, either "cross_entropy", "mse", or "mae" (default: None).
dataset_loss : dict
Dictionary containing the dataset loss information.
modules_losses : AbstractLoss, optional
Custom modules' losses to be used.
"""

def __init__(self, task, loss_type=None):
def __init__(self, dataset_loss, modules_losses={}): # noqa: B006
super().__init__()
self.task = task
if task == "classification" and loss_type == "cross_entropy":
self.criterion = torch.nn.CrossEntropyLoss()

elif task == "regression" and loss_type == "mse":
self.criterion = torch.nn.MSELoss()

elif task == "regression" and loss_type == "mae":
self.criterion = torch.nn.L1Loss()

else:
raise Exception("Loss is not defined")
self.loss_type = loss_type
self.losses = []
# Dataset loss
self.losses.append(DatasetLoss(dataset_loss))
# Model losses
self.losses.extend(
[loss for loss in modules_losses.values() if loss is not None]
)

def __repr__(self) -> str:
return f"{self.__class__.__name__}(task={self.task}, loss_type={self.loss_type})"
Expand All @@ -51,12 +46,8 @@ def forward(self, model_out: dict, batch: torch_geometric.data.Data):
dict
Dictionary containing the model output with the loss.
"""
logits = model_out["logits"]
target = model_out["labels"]

if self.task == "regression":
target = target.unsqueeze(1)
losses = [loss(model_out, batch) for loss in self.losses]

model_out["loss"] = self.criterion(logits, target)
model_out["loss"] = torch.stack(losses).sum()

return model_out
96 changes: 96 additions & 0 deletions topobenchmarkx/loss/model/GraphMLPLoss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
"""Graph MLP loss function."""

import torch
import torch_geometric

from topobenchmarkx.loss.base import AbstractLoss


class GraphMLPLoss(AbstractLoss):
r"""Graph MLP loss function.
Parameters
----------
r_adj_power : int, optional
Power of the adjacency matrix (default: 2).
tau : float, optional
Temperature parameter (default: 1).
loss_weight : float, optional
Loss weight (default: 0.5).
"""

def __init__(self, r_adj_power=2, tau=1.0, loss_weight=0.5):
super().__init__()
self.r_adj_power = r_adj_power
self.tau = tau
self.loss_weight = loss_weight

def __repr__(self) -> str:
return f"{self.__class__.__name__}(r_adj_power={self.r_adj_power}, tau={self.tau}, loss_weight={self.loss_weight})"

def get_power_adj(self, edge_index):
r"""Get the power of the adjacency matrix.
Parameters
----------
edge_index : torch.Tensor
Edge index tensor.
Returns
-------
torch.Tensor
Power of the adjacency matrix.
"""
edge_index, _ = torch_geometric.utils.remove_self_loops(edge_index)
adj = torch_geometric.utils.to_dense_adj(edge_index)
adj_power = adj.clone()
for _ in range(self.r_adj_power - 1):
adj_power = torch.matmul(adj_power, adj)
return adj_power

def graph_mlp_contrast_loss(self, x_dis, adj_label):
"""Graph MLP contrastive loss.
Parameters
----------
x_dis : torch.Tensor
Distance matrix.
adj_label : torch.Tensor
Adjacency matrix.
Returns
-------
torch.Tensor
Contrastive loss.
"""
x_dis = torch.exp(self.tau * x_dis)
x_dis_sum = torch.sum(x_dis, 1)
x_dis_sum_pos = torch.sum(x_dis * adj_label, 1)
loss = -torch.log(x_dis_sum_pos * (x_dis_sum ** (-1)) + 1e-8).mean()
return loss

def forward(
self, model_out: dict, batch: torch_geometric.data.Data
) -> torch.Tensor:
r"""Forward pass of the loss function.
Parameters
----------
model_out : dict
Dictionary containing the model output.
batch : torch_geometric.data.Data
Batch object containing the batched domain data.
Returns
-------
dict
Dictionary containing the model output with the loss.
"""
x_dis = model_out["x_dis"]
if x_dis is None: # Validation and test
return torch.tensor(0.0)
adj_label = self.get_power_adj(batch.edge_index)
graph_mlp_loss = self.loss_weight * self.graph_mlp_contrast_loss(
x_dis, adj_label
)
return graph_mlp_loss
7 changes: 7 additions & 0 deletions topobenchmarkx/loss/model/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
"""Init file for custom loss module."""

from .GraphMLPLoss import GraphMLPLoss

__all__ = [
"GraphMLPLoss",
]
Loading

0 comments on commit 6b5c349

Please sign in to comment.