-
Notifications
You must be signed in to change notification settings - Fork 15
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #112 from geometric-intelligence/loss
Loss
- Loading branch information
Showing
23 changed files
with
516 additions
and
44 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,8 @@ | ||
_target_: topobenchmarkx.loss.TBXLoss | ||
task: ${dataset.parameters.task} | ||
loss_type: ${dataset.parameters.loss_type} | ||
dataset_loss: | ||
task: ${dataset.parameters.task} | ||
loss_type: ${dataset.parameters.loss_type} | ||
modules_losses: # Collect model losses | ||
feature_encoder: ${oc.select:model.feature_encoder.loss,null} | ||
backbone: ${oc.select:model.backbone.loss,null} | ||
readout: ${oc.select:model.readout.loss,null} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
_target_: topobenchmarkx.model.TBXModel | ||
|
||
model_name: GraphMLP | ||
model_domain: graph | ||
|
||
feature_encoder: | ||
_target_: topobenchmarkx.nn.encoders.${model.feature_encoder.encoder_name} | ||
encoder_name: AllCellFeatureEncoder | ||
in_channels: ${infer_in_channels:${dataset},${oc.select:transforms,null}} | ||
out_channels: 32 | ||
proj_dropout: 0.0 | ||
|
||
backbone: | ||
_target_: topobenchmarkx.nn.backbones.GraphMLP | ||
in_channels: ${model.feature_encoder.out_channels} | ||
hidden_channels: ${model.feature_encoder.out_channels} | ||
order: 2 | ||
dropout: 0.0 | ||
loss: | ||
_target_: topobenchmarkx.loss.model.GraphMLPLoss | ||
r_adj_power: 2 | ||
tau: 1. | ||
loss_weight: 0.5 | ||
|
||
backbone_wrapper: | ||
_target_: topobenchmarkx.nn.wrappers.GraphMLPWrapper | ||
_partial_: true | ||
wrapper_name: GraphMLPWrapper | ||
out_channels: ${model.feature_encoder.out_channels} | ||
num_cell_dimensions: ${infere_num_cell_dimensions:${oc.select:model.feature_encoder.selected_dimensions,null},${model.feature_encoder.in_channels}} | ||
|
||
readout: | ||
_target_: topobenchmarkx.nn.readouts.${model.readout.readout_name} | ||
readout_name: NoReadOut # Use <NoReadOut> in case readout is not needed Options: PropagateSignalDown | ||
num_cell_dimensions: ${infere_num_cell_dimensions:${oc.select:model.feature_encoder.selected_dimensions,null},${model.feature_encoder.in_channels}} # The highest order of cell dimensions to consider | ||
hidden_dim: ${model.feature_encoder.out_channels} | ||
out_channels: ${dataset.parameters.num_classes} | ||
task_level: ${dataset.parameters.task_level} | ||
pooling_type: sum | ||
|
||
|
||
|
||
# compile model for faster training with pytorch 2.0 | ||
compile: false |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
""" Test the TBXEvaluator class.""" | ||
import pytest | ||
import torch | ||
import torch_geometric | ||
|
||
from topobenchmarkx.loss.dataset import DatasetLoss | ||
|
||
class TestDatasetLoss: | ||
""" Test the TBXEvaluator class.""" | ||
|
||
def setup_method(self): | ||
""" Setup the test.""" | ||
dataset_loss = {"task": "classification", "loss_type": "cross_entropy"} | ||
self.dataset1 = DatasetLoss(dataset_loss) | ||
dataset_loss = {"task": "regression", "loss_type": "mse"} | ||
self.dataset2 = DatasetLoss(dataset_loss) | ||
dataset_loss = {"task": "regression", "loss_type": "mae"} | ||
self.dataset3 = DatasetLoss(dataset_loss) | ||
dataset_loss = {"task": "wrong", "loss_type": "wrong"} | ||
with pytest.raises(Exception): | ||
DatasetLoss(dataset_loss) | ||
repr = self.dataset1.__repr__() | ||
assert repr == "DatasetLoss(task=classification, loss_type=cross_entropy)" | ||
|
||
def test_forward(self): | ||
""" Test the forward method.""" | ||
batch = torch_geometric.data.Data() | ||
model_out = {"logits": torch.tensor([0.1, 0.2, 0.3]), "labels": torch.tensor([0.1, 0.2, 0.3])} | ||
out = self.dataset1.forward(model_out, batch) | ||
assert out.item() >= 0 | ||
model_out = {"logits": torch.tensor([0.1, 0.2, 0.3]), "labels": torch.tensor([0.1, 0.2, 0.3])} | ||
out = self.dataset3.forward(model_out, batch) | ||
assert out.item() >= 0 | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
"""Init file for custom graph models.""" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
"""Unit tests for GraphMLP.""" | ||
|
||
import torch | ||
import torch_geometric | ||
from topobenchmarkx.nn.backbones.graph import GraphMLP | ||
from topobenchmarkx.nn.wrappers.graph import GraphMLPWrapper | ||
from topobenchmarkx.loss.model import GraphMLPLoss | ||
|
||
def testGraphMLP(random_graph_input): | ||
""" Unit test for GraphMLP. | ||
Parameters | ||
---------- | ||
random_graph_input : Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor]] | ||
A tuple of input tensors for testing EDGNN. | ||
""" | ||
x, x_1, x_2, edges_1, edges_2 = random_graph_input | ||
batch = torch_geometric.data.Data(x_0=x, y=x, edge_index=edges_1, batch_0=torch.zeros(x.shape[0], dtype=torch.long)) | ||
model = GraphMLP(x.shape[1], x.shape[1]) | ||
wrapper = GraphMLPWrapper(model, **{"out_channels": x.shape[1], "num_cell_dimensions": 1}) | ||
loss_fn = GraphMLPLoss() | ||
|
||
_ = wrapper.__repr__() | ||
_ = loss_fn.__repr__() | ||
|
||
model_out = wrapper(batch) | ||
assert model_out["x_0"].shape == x.shape | ||
assert list(model_out["x_dis"].shape) == [8,8] | ||
|
||
loss = loss_fn(model_out, batch) | ||
assert loss.item() >= 0 | ||
|
||
model_out["x_dis"] = None | ||
loss = loss_fn(model_out, batch) | ||
assert loss == torch.tensor(0.0) | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
"""Loss module for the topobenchmarkx package.""" | ||
|
||
import torch | ||
import torch_geometric | ||
|
||
from topobenchmarkx.loss.base import AbstractLoss | ||
|
||
|
||
class DatasetLoss(AbstractLoss): | ||
r"""Defines the default model loss for the given task. | ||
Parameters | ||
---------- | ||
dataset_loss : dict | ||
Dictionary containing the dataset loss information. | ||
""" | ||
|
||
def __init__(self, dataset_loss): | ||
super().__init__() | ||
self.task = dataset_loss["task"] | ||
self.loss_type = dataset_loss["loss_type"] | ||
# Dataset loss | ||
if self.task == "classification" and self.loss_type == "cross_entropy": | ||
self.criterion = torch.nn.CrossEntropyLoss() | ||
elif self.task == "regression" and self.loss_type == "mse": | ||
self.criterion = torch.nn.MSELoss() | ||
elif self.task == "regression" and self.loss_type == "mae": | ||
self.criterion = torch.nn.L1Loss() | ||
else: | ||
raise Exception("Loss is not defined") | ||
|
||
def __repr__(self) -> str: | ||
return f"{self.__class__.__name__}(task={self.task}, loss_type={self.loss_type})" | ||
|
||
def forward(self, model_out: dict, batch: torch_geometric.data.Data): | ||
r"""Forward pass of the loss function. | ||
Parameters | ||
---------- | ||
model_out : dict | ||
Dictionary containing the model output. | ||
batch : torch_geometric.data.Data | ||
Batch object containing the batched domain data. | ||
Returns | ||
------- | ||
dict | ||
Dictionary containing the model output with the loss. | ||
""" | ||
logits = model_out["logits"] | ||
target = model_out["labels"] | ||
|
||
if self.task == "regression": | ||
target = target.unsqueeze(1) | ||
|
||
dataset_loss = self.criterion(logits, target) | ||
|
||
return dataset_loss |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
"""Init file for custom loss module.""" | ||
|
||
from .DatasetLoss import DatasetLoss | ||
|
||
__all__ = [ | ||
"DatasetLoss", | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,96 @@ | ||
"""Graph MLP loss function.""" | ||
|
||
import torch | ||
import torch_geometric | ||
|
||
from topobenchmarkx.loss.base import AbstractLoss | ||
|
||
|
||
class GraphMLPLoss(AbstractLoss): | ||
r"""Graph MLP loss function. | ||
Parameters | ||
---------- | ||
r_adj_power : int, optional | ||
Power of the adjacency matrix (default: 2). | ||
tau : float, optional | ||
Temperature parameter (default: 1). | ||
loss_weight : float, optional | ||
Loss weight (default: 0.5). | ||
""" | ||
|
||
def __init__(self, r_adj_power=2, tau=1.0, loss_weight=0.5): | ||
super().__init__() | ||
self.r_adj_power = r_adj_power | ||
self.tau = tau | ||
self.loss_weight = loss_weight | ||
|
||
def __repr__(self) -> str: | ||
return f"{self.__class__.__name__}(r_adj_power={self.r_adj_power}, tau={self.tau}, loss_weight={self.loss_weight})" | ||
|
||
def get_power_adj(self, edge_index): | ||
r"""Get the power of the adjacency matrix. | ||
Parameters | ||
---------- | ||
edge_index : torch.Tensor | ||
Edge index tensor. | ||
Returns | ||
------- | ||
torch.Tensor | ||
Power of the adjacency matrix. | ||
""" | ||
edge_index, _ = torch_geometric.utils.remove_self_loops(edge_index) | ||
adj = torch_geometric.utils.to_dense_adj(edge_index) | ||
adj_power = adj.clone() | ||
for _ in range(self.r_adj_power - 1): | ||
adj_power = torch.matmul(adj_power, adj) | ||
return adj_power | ||
|
||
def graph_mlp_contrast_loss(self, x_dis, adj_label): | ||
"""Graph MLP contrastive loss. | ||
Parameters | ||
---------- | ||
x_dis : torch.Tensor | ||
Distance matrix. | ||
adj_label : torch.Tensor | ||
Adjacency matrix. | ||
Returns | ||
------- | ||
torch.Tensor | ||
Contrastive loss. | ||
""" | ||
x_dis = torch.exp(self.tau * x_dis) | ||
x_dis_sum = torch.sum(x_dis, 1) | ||
x_dis_sum_pos = torch.sum(x_dis * adj_label, 1) | ||
loss = -torch.log(x_dis_sum_pos * (x_dis_sum ** (-1)) + 1e-8).mean() | ||
return loss | ||
|
||
def forward( | ||
self, model_out: dict, batch: torch_geometric.data.Data | ||
) -> torch.Tensor: | ||
r"""Forward pass of the loss function. | ||
Parameters | ||
---------- | ||
model_out : dict | ||
Dictionary containing the model output. | ||
batch : torch_geometric.data.Data | ||
Batch object containing the batched domain data. | ||
Returns | ||
------- | ||
dict | ||
Dictionary containing the model output with the loss. | ||
""" | ||
x_dis = model_out["x_dis"] | ||
if x_dis is None: # Validation and test | ||
return torch.tensor(0.0) | ||
adj_label = self.get_power_adj(batch.edge_index) | ||
graph_mlp_loss = self.loss_weight * self.graph_mlp_contrast_loss( | ||
x_dis, adj_label | ||
) | ||
return graph_mlp_loss |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
"""Init file for custom loss module.""" | ||
|
||
from .GraphMLPLoss import GraphMLPLoss | ||
|
||
__all__ = [ | ||
"GraphMLPLoss", | ||
] |
Oops, something went wrong.