diff --git a/examples/dgcnn_segmentation.py b/examples/dgcnn_segmentation.py index f9d6d17dc0a6..9b0581f98f13 100644 --- a/examples/dgcnn_segmentation.py +++ b/examples/dgcnn_segmentation.py @@ -2,12 +2,13 @@ import torch import torch.nn.functional as F +from torch_scatter import scatter +from torchmetrics.functional import jaccard_index import torch_geometric.transforms as T from torch_geometric.datasets import ShapeNet from torch_geometric.loader import DataLoader from torch_geometric.nn import MLP, DynamicEdgeConv -from torch_geometric.utils import intersection_and_union as i_and_u category = 'Airplane' # Pass in `None` to train on all categories. path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'ShapeNet') @@ -80,24 +81,32 @@ def train(): def test(loader): model.eval() - y_mask = loader.dataset.y_mask - ious = [[] for _ in range(len(loader.dataset.categories))] - + ious, categories = [], [] + y_map = torch.empty(loader.dataset.num_classes, device=device).long() for data in loader: data = data.to(device) - pred = model(data).argmax(dim=1) + outs = model(data) + + sizes = (data.ptr[1:] - data.ptr[:-1]).tolist() + for out, y, category in zip(outs.split(sizes), data.y.split(sizes), + data.category.tolist()): + category = list(ShapeNet.seg_classes.keys())[category] + part = ShapeNet.seg_classes[category] + part = torch.tensor(part, device=device) + + y_map[part] = torch.arange(part.size(0), device=device) + + iou = jaccard_index(out[:, part].argmax(dim=-1), y_map[y], + num_classes=part.size(0), absent_score=1.0) + ious.append(iou) - i, u = i_and_u(pred, data.y, loader.dataset.num_classes, data.batch) - iou = i.cpu().to(torch.float) / u.cpu().to(torch.float) - iou[torch.isnan(iou)] = 1 + categories.append(data.category) - # Find and filter the relevant classes for each category. - for iou, category in zip(iou.unbind(), data.category.unbind()): - ious[category.item()].append(iou[y_mask[category]]) + iou = torch.tensor(ious, device=device) + category = torch.cat(categories, dim=0) - # Compute mean IoU. - ious = [torch.stack(iou).mean(0).mean(0) for iou in ious] - return torch.tensor(ious).mean().item() + mean_iou = scatter(iou, category, reduce='mean') # Per-category IoU. + return float(mean_iou.mean()) # Global IoU. for epoch in range(1, 31): diff --git a/examples/point_transformer_segmentation.py b/examples/point_transformer_segmentation.py index 4db7756d1e8d..3ecfa724ed29 100644 --- a/examples/point_transformer_segmentation.py +++ b/examples/point_transformer_segmentation.py @@ -11,12 +11,13 @@ from torch.nn import ReLU from torch.nn import Sequential as Seq from torch_cluster import knn_graph +from torch_scatter import scatter +from torchmetrics.functional import jaccard_index import torch_geometric.transforms as T from torch_geometric.datasets import ShapeNet from torch_geometric.loader import DataLoader from torch_geometric.nn.unpool import knn_interpolate -from torch_geometric.utils import intersection_and_union as i_and_u category = 'Airplane' # Pass in `None` to train on all categories. path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'ShapeNet') @@ -197,24 +198,32 @@ def train(): def test(loader): model.eval() - y_mask = loader.dataset.y_mask - ious = [[] for _ in range(len(loader.dataset.categories))] - + ious, categories = [], [] + y_map = torch.empty(loader.dataset.num_classes, device=device).long() for data in loader: data = data.to(device) - pred = model(data.x, data.pos, data.batch).argmax(dim=1) + outs = model(data.x, data.pos, data.batch) + + sizes = (data.ptr[1:] - data.ptr[:-1]).tolist() + for out, y, category in zip(outs.split(sizes), data.y.split(sizes), + data.category.tolist()): + category = list(ShapeNet.seg_classes.keys())[category] + part = ShapeNet.seg_classes[category] + part = torch.tensor(part, device=device) + + y_map[part] = torch.arange(part.size(0), device=device) + + iou = jaccard_index(out[:, part].argmax(dim=-1), y_map[y], + num_classes=part.size(0), absent_score=1.0) + ious.append(iou) - i, u = i_and_u(pred, data.y, loader.dataset.num_classes, data.batch) - iou = i.cpu().to(torch.float) / u.cpu().to(torch.float) - iou[torch.isnan(iou)] = 1 + categories.append(data.category) - # Find and filter the relevant classes for each category. - for iou, category in zip(iou.unbind(), data.category.unbind()): - ious[category.item()].append(iou[y_mask[category]]) + iou = torch.tensor(ious, device=device) + category = torch.cat(categories, dim=0) - # Compute mean IoU. - ious = [torch.stack(iou).mean(0).mean(0) for iou in ious] - return torch.tensor(ious).mean().item() + mean_iou = scatter(iou, category, reduce='mean') # Per-category IoU. + return float(mean_iou.mean()) # Global IoU. for epoch in range(1, 100): diff --git a/examples/pointnet2_segmentation.py b/examples/pointnet2_segmentation.py index ee20c5637a0e..15a50723ea47 100644 --- a/examples/pointnet2_segmentation.py +++ b/examples/pointnet2_segmentation.py @@ -3,12 +3,13 @@ import torch import torch.nn.functional as F from pointnet2_classification import GlobalSAModule, SAModule +from torch_scatter import scatter +from torchmetrics.functional import jaccard_index import torch_geometric.transforms as T from torch_geometric.datasets import ShapeNet from torch_geometric.loader import DataLoader from torch_geometric.nn import MLP, knn_interpolate -from torch_geometric.utils import intersection_and_union as i_and_u category = 'Airplane' # Pass in `None` to train on all categories. path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'ShapeNet') @@ -106,24 +107,32 @@ def train(): def test(loader): model.eval() - y_mask = loader.dataset.y_mask - ious = [[] for _ in range(len(loader.dataset.categories))] - + ious, categories = [], [] + y_map = torch.empty(loader.dataset.num_classes, device=device).long() for data in loader: data = data.to(device) - pred = model(data).argmax(dim=1) + outs = model(data) + + sizes = (data.ptr[1:] - data.ptr[:-1]).tolist() + for out, y, category in zip(outs.split(sizes), data.y.split(sizes), + data.category.tolist()): + category = list(ShapeNet.seg_classes.keys())[category] + part = ShapeNet.seg_classes[category] + part = torch.tensor(part, device=device) + + y_map[part] = torch.arange(part.size(0), device=device) + + iou = jaccard_index(out[:, part].argmax(dim=-1), y_map[y], + num_classes=part.size(0), absent_score=1.0) + ious.append(iou) - i, u = i_and_u(pred, data.y, loader.dataset.num_classes, data.batch) - iou = i.cpu().to(torch.float) / u.cpu().to(torch.float) - iou[torch.isnan(iou)] = 1 + categories.append(data.category) - # Find and filter the relevant classes for each category. - for iou, category in zip(iou.unbind(), data.category.unbind()): - ious[category.item()].append(iou[y_mask[category]]) + iou = torch.tensor(ious, device=device) + category = torch.cat(categories, dim=0) - # Compute mean IoU. - ious = [torch.stack(iou).mean(0).mean(0) for iou in ious] - return torch.tensor(ious).mean().item() + mean_iou = scatter(iou, category, reduce='mean') # Per-category IoU. + return float(mean_iou.mean()) # Global IoU. for epoch in range(1, 31): diff --git a/setup.py b/setup.py index 4640a6806c31..f6941828eee8 100644 --- a/setup.py +++ b/setup.py @@ -27,6 +27,7 @@ 'matplotlib', 'scikit-image', 'pytorch-memlab', + 'torchmetrics>=0.7', 'class-resolver>=0.3.2', ] diff --git a/test/utils/test_metric.py b/test/utils/test_metric.py deleted file mode 100644 index 834efbccc3ca..000000000000 --- a/test/utils/test_metric.py +++ /dev/null @@ -1,41 +0,0 @@ -import torch - -from torch_geometric.utils import ( - accuracy, - f1_score, - false_negative, - false_positive, - mean_iou, - precision, - recall, - true_negative, - true_positive, -) - - -def test_metric(): - pred = torch.tensor([0, 0, 1, 1]) - target = torch.tensor([0, 1, 0, 1]) - - assert accuracy(pred, target) == 0.5 - assert true_positive(pred, target, num_classes=2).tolist() == [1, 1] - assert true_negative(pred, target, num_classes=2).tolist() == [1, 1] - assert false_positive(pred, target, num_classes=2).tolist() == [1, 1] - assert false_negative(pred, target, num_classes=2).tolist() == [1, 1] - assert precision(pred, target, num_classes=2).tolist() == [0.5, 0.5] - assert recall(pred, target, num_classes=2).tolist() == [0.5, 0.5] - assert f1_score(pred, target, num_classes=2).tolist() == [0.5, 0.5] - - -def test_mean_iou(): - pred = torch.tensor([0, 0, 1, 1, 0, 1]) - target = torch.tensor([0, 1, 0, 1, 0, 0]) - - out = mean_iou(pred, target, num_classes=2) - assert out == (0.4 + 0.25) / 2 - - batch = torch.tensor([0, 0, 0, 0, 1, 1]) - out = mean_iou(pred, target, num_classes=2, batch=batch) - assert out.size() == (2, ) - assert out[0] == (1 / 3 + 1 / 3) / 2 - assert out[1] == 0.25 diff --git a/torch_geometric/utils/__init__.py b/torch_geometric/utils/__init__.py index d4446a9d35ee..81a397afa297 100644 --- a/torch_geometric/utils/__init__.py +++ b/torch_geometric/utils/__init__.py @@ -32,9 +32,6 @@ structured_negative_sampling, structured_negative_sampling_feasible) from .train_test_split_edges import train_test_split_edges -from .metric import (accuracy, true_positive, true_negative, false_positive, - false_negative, precision, recall, f1_score, - intersection_and_union, mean_iou) __all__ = [ 'degree', @@ -82,16 +79,6 @@ 'structured_negative_sampling', 'structured_negative_sampling_feasible', 'train_test_split_edges', - 'accuracy', - 'true_positive', - 'true_negative', - 'false_positive', - 'false_negative', - 'precision', - 'recall', - 'f1_score', - 'intersection_and_union', - 'mean_iou', ] classes = __all__ diff --git a/torch_geometric/utils/metric.py b/torch_geometric/utils/metric.py deleted file mode 100644 index 8479501309a5..000000000000 --- a/torch_geometric/utils/metric.py +++ /dev/null @@ -1,201 +0,0 @@ -from typing import Optional, Tuple - -import torch -import torch.nn.functional as F -from torch import Tensor -from torch_scatter import scatter_add - - -def accuracy(pred: Tensor, target: Tensor) -> float: - r"""Computes the accuracy of predictions. - - Args: - pred (Tensor): The predictions. - target (Tensor): The targets. - - :rtype: float - """ - return int((pred == target).sum()) / target.numel() - - -def true_positive(pred: Tensor, target: Tensor, num_classes: int) -> Tensor: - r"""Computes the number of true positive predictions. - - Args: - pred (Tensor): The predictions. - target (Tensor): The targets. - num_classes (int): The number of classes. - - :rtype: :class:`LongTensor` - """ - out = [] - for i in range(num_classes): - out.append(((pred == i) & (target == i)).sum()) - - return torch.tensor(out, device=pred.device) - - -def true_negative(pred: Tensor, target: Tensor, num_classes: int) -> Tensor: - r"""Computes the number of true negative predictions. - - Args: - pred (Tensor): The predictions. - target (Tensor): The targets. - num_classes (int): The number of classes. - - :rtype: :class:`LongTensor` - """ - out = [] - for i in range(num_classes): - out.append(((pred != i) & (target != i)).sum()) - - return torch.tensor(out, device=pred.device) - - -def false_positive(pred: Tensor, target: Tensor, num_classes: int) -> Tensor: - r"""Computes the number of false positive predictions. - - Args: - pred (Tensor): The predictions. - target (Tensor): The targets. - num_classes (int): The number of classes. - - :rtype: :class:`LongTensor` - """ - out = [] - for i in range(num_classes): - out.append(((pred == i) & (target != i)).sum()) - - return torch.tensor(out, device=pred.device) - - -def false_negative(pred: Tensor, target: Tensor, num_classes: int) -> Tensor: - r"""Computes the number of false negative predictions. - - Args: - pred (Tensor): The predictions. - target (Tensor): The targets. - num_classes (int): The number of classes. - - :rtype: :class:`LongTensor` - """ - out = [] - for i in range(num_classes): - out.append(((pred != i) & (target == i)).sum()) - - return torch.tensor(out, device=pred.device) - - -def precision(pred: Tensor, target: Tensor, num_classes: int) -> Tensor: - r"""Computes the precision - :math:`\frac{\mathrm{TP}}{\mathrm{TP}+\mathrm{FP}}` of predictions. - - Args: - pred (Tensor): The predictions. - target (Tensor): The targets. - num_classes (int): The number of classes. - - :rtype: :class:`Tensor` - """ - tp = true_positive(pred, target, num_classes).to(torch.float) - fp = false_positive(pred, target, num_classes).to(torch.float) - - out = tp / (tp + fp) - out[torch.isnan(out)] = 0 - - return out - - -def recall(pred: Tensor, target: Tensor, num_classes: int) -> Tensor: - r"""Computes the recall - :math:`\frac{\mathrm{TP}}{\mathrm{TP}+\mathrm{FN}}` of predictions. - - Args: - pred (Tensor): The predictions. - target (Tensor): The targets. - num_classes (int): The number of classes. - - :rtype: :class:`Tensor` - """ - tp = true_positive(pred, target, num_classes).to(torch.float) - fn = false_negative(pred, target, num_classes).to(torch.float) - - out = tp / (tp + fn) - out[torch.isnan(out)] = 0 - - return out - - -def f1_score(pred: Tensor, target: Tensor, num_classes: int) -> Tensor: - r"""Computes the :math:`F_1` score - :math:`2 \cdot \frac{\mathrm{precision} \cdot \mathrm{recall}} - {\mathrm{precision}+\mathrm{recall}}` of predictions. - - Args: - pred (Tensor): The predictions. - target (Tensor): The targets. - num_classes (int): The number of classes. - - :rtype: :class:`Tensor` - """ - prec = precision(pred, target, num_classes) - rec = recall(pred, target, num_classes) - - score = 2 * (prec * rec) / (prec + rec) - score[torch.isnan(score)] = 0 - - return score - - -def intersection_and_union( - pred: Tensor, target: Tensor, num_classes: int, - batch: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]: - r"""Computes intersection and union of predictions. - - Args: - pred (LongTensor): The predictions. - target (LongTensor): The targets. - num_classes (int): The number of classes. - batch (LongTensor): The assignment vector which maps each pred-target - pair to an example. - - :rtype: (:class:`LongTensor`, :class:`LongTensor`) - """ - pred, target = F.one_hot(pred, num_classes), F.one_hot(target, num_classes) - - if batch is None: - i = (pred & target).sum(dim=0) - u = (pred | target).sum(dim=0) - else: - i = scatter_add(pred & target, batch, dim=0) - u = scatter_add(pred | target, batch, dim=0) - - return i, u - - -def mean_iou(pred: Tensor, target: Tensor, num_classes: int, - batch: Optional[Tensor] = None, omitnans: bool = False) -> Tensor: - r"""Computes the mean intersection over union score of predictions. - - Args: - pred (LongTensor): The predictions. - target (LongTensor): The targets. - num_classes (int): The number of classes. - batch (LongTensor): The assignment vector which maps each pred-target - pair to an example. - omitnans (bool, optional): If set to :obj:`True`, will ignore any - :obj:`NaN` values encountered during computation. Otherwise, will - treat them as :obj:`1`. (default: :obj:`False`) - - :rtype: :class:`Tensor` - """ - i, u = intersection_and_union(pred, target, num_classes, batch) - iou = i.to(torch.float) / u.to(torch.float) - - if omitnans: - iou = iou[~iou.isnan()].mean() - else: - iou[torch.isnan(iou)] = 1. - iou = iou.mean(dim=-1) - - return iou