From d14883e82de418cfe6b6622c28a2c964edc99ed1 Mon Sep 17 00:00:00 2001 From: Mohamad Zamini <32536264+mzamini92@users.noreply.github.com> Date: Tue, 6 Jun 2023 11:18:36 -0600 Subject: [PATCH 1/4] Update local_cartesian.py The intermediate tensors `cart` and `max_value` in the original code were replaced with in-place operations to reduce memory usage. This was done by directly operating on the `cart` tensor and computing the maximum value iteratively without creating a separate `max_value` tensor. In-place operations `(torch.sub, cart.div_, cart.mul_, cart.add_)` were used to perform computations directly on tensors, reducing memory usage and eliminating the need for intermediate tensors. To compute the maximum value in a streaming fashion, a loop was introduced in the combined version. This loop iterates over the edges and updates the maximum value tensor (max_value) accordingly. --- torch_geometric/transforms/local_cartesian.py | 35 ++++++++++++------- 1 file changed, 22 insertions(+), 13 deletions(-) diff --git a/torch_geometric/transforms/local_cartesian.py b/torch_geometric/transforms/local_cartesian.py index ee7916aeb297..4ca6d67db2ac 100644 --- a/torch_geometric/transforms/local_cartesian.py +++ b/torch_geometric/transforms/local_cartesian.py @@ -1,41 +1,50 @@ import torch - from torch_geometric.data import Data from torch_geometric.data.datapipes import functional_transform from torch_geometric.transforms import BaseTransform -from torch_geometric.utils import scatter - @functional_transform('local_cartesian') class LocalCartesian(BaseTransform): r"""Saves the relative Cartesian coordinates of linked nodes in its edge attributes (functional name: :obj:`local_cartesian`). Each coordinate gets - *neighborhood-normalized* to the interval :math:`{[0, 1]}^D`. + *neighborhood-normalized* to the specified interval. Args: norm (bool, optional): If set to :obj:`False`, the output will not be - normalized to the interval :math:`{[0, 1]}^D`. - (default: :obj:`True`) + normalized. (default: :obj:`True`) cat (bool, optional): If set to :obj:`False`, all existing edge attributes will be replaced. (default: :obj:`True`) + norm_range (tuple, optional): Tuple specifying the range for normalization. + Each element of the tuple represents the lower and upper bounds for a + coordinate dimension. (default: :obj:`(0, 1)`) """ - def __init__(self, norm: bool = True, cat: bool = True): + def __init__(self, norm=True, cat=True, norm_range=(0, 1)): self.norm = norm self.cat = cat + self.norm_range = norm_range def forward(self, data: Data) -> Data: (row, col), pos, pseudo = data.edge_index, data.pos, data.edge_attr - cart = pos[row] - pos[col] - cart = cart.view(-1, 1) if cart.dim() == 1 else cart + cart = torch.empty(row.size(0), pos.size(1), device=pos.device) + torch.sub(pos[row], pos[col], out=cart) # In-place subtraction + + max_value = torch.empty(pos.size(0), device=pos.device) + max_value.fill_(float('-inf')) + + for i in range(row.size(0)): + cart_abs = cart[i].abs() + max_value.index_copy_(0, col, torch.maximum(max_value[col], cart_abs)) - max_value = scatter(cart.abs(), col, 0, pos.size(0), reduce='max') - max_value = max_value.max(dim=-1, keepdim=True)[0] + max_value = torch.max(max_value) if self.norm: - cart = cart / (2 * max_value[col]) + 0.5 + norm_range_min, norm_range_max = self.norm_range + norm_factor = 2 * max_value + norm_factor[norm_factor == 0] = 1 # Avoid division by zero + cart.div_(norm_factor).mul_(norm_range_max - norm_range_min).add_(norm_range_min) # In-place normalization else: - cart = cart / max_value[col] + cart = cart if pseudo is not None and self.cat: pseudo = pseudo.view(-1, 1) if pseudo.dim() == 1 else pseudo From c7d1f7e64aaf8dd747ddcb9971c3d38ce04450d3 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 6 Jun 2023 17:20:47 +0000 Subject: [PATCH 2/4] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- torch_geometric/transforms/local_cartesian.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/torch_geometric/transforms/local_cartesian.py b/torch_geometric/transforms/local_cartesian.py index 4ca6d67db2ac..0062dae654d6 100644 --- a/torch_geometric/transforms/local_cartesian.py +++ b/torch_geometric/transforms/local_cartesian.py @@ -1,8 +1,10 @@ import torch + from torch_geometric.data import Data from torch_geometric.data.datapipes import functional_transform from torch_geometric.transforms import BaseTransform + @functional_transform('local_cartesian') class LocalCartesian(BaseTransform): r"""Saves the relative Cartesian coordinates of linked nodes in its edge @@ -34,7 +36,8 @@ def forward(self, data: Data) -> Data: for i in range(row.size(0)): cart_abs = cart[i].abs() - max_value.index_copy_(0, col, torch.maximum(max_value[col], cart_abs)) + max_value.index_copy_(0, col, + torch.maximum(max_value[col], cart_abs)) max_value = torch.max(max_value) @@ -42,7 +45,8 @@ def forward(self, data: Data) -> Data: norm_range_min, norm_range_max = self.norm_range norm_factor = 2 * max_value norm_factor[norm_factor == 0] = 1 # Avoid division by zero - cart.div_(norm_factor).mul_(norm_range_max - norm_range_min).add_(norm_range_min) # In-place normalization + cart.div_(norm_factor).mul_(norm_range_max - norm_range_min).add_( + norm_range_min) # In-place normalization else: cart = cart From 75120785bda8a2d6dd1fef9e9253dd00b6b7ead6 Mon Sep 17 00:00:00 2001 From: rusty1s Date: Tue, 20 Jun 2023 13:05:38 +0000 Subject: [PATCH 3/4] update --- test/transforms/test_local_cartesian.py | 3 +- torch_geometric/transforms/local_cartesian.py | 47 +++++++++---------- 2 files changed, 24 insertions(+), 26 deletions(-) diff --git a/test/transforms/test_local_cartesian.py b/test/transforms/test_local_cartesian.py index fdc752aaa01e..a6952b0d5ba1 100644 --- a/test/transforms/test_local_cartesian.py +++ b/test/transforms/test_local_cartesian.py @@ -5,7 +5,7 @@ def test_local_cartesian(): - transform = LocalCartesian() + transform = LocalCartesian(interval=[0, 1]) assert str(transform) == 'LocalCartesian()' pos = torch.Tensor([[-1, 0], [0, 0], [2, 0]]) @@ -14,6 +14,7 @@ def test_local_cartesian(): data = Data(edge_index=edge_index, pos=pos) data = transform(data) + assert len(data) == 3 assert data.pos.tolist() == pos.tolist() assert data.edge_index.tolist() == edge_index.tolist() diff --git a/torch_geometric/transforms/local_cartesian.py b/torch_geometric/transforms/local_cartesian.py index 0062dae654d6..d36455ed5c5c 100644 --- a/torch_geometric/transforms/local_cartesian.py +++ b/torch_geometric/transforms/local_cartesian.py @@ -1,54 +1,51 @@ +from typing import Tuple + import torch from torch_geometric.data import Data from torch_geometric.data.datapipes import functional_transform from torch_geometric.transforms import BaseTransform +from torch_geometric.utils import scatter @functional_transform('local_cartesian') class LocalCartesian(BaseTransform): r"""Saves the relative Cartesian coordinates of linked nodes in its edge attributes (functional name: :obj:`local_cartesian`). Each coordinate gets - *neighborhood-normalized* to the specified interval. + *neighborhood-normalized* to a specified interval + (:math:`[0, 1]` by default). Args: norm (bool, optional): If set to :obj:`False`, the output will not be normalized. (default: :obj:`True`) cat (bool, optional): If set to :obj:`False`, all existing edge attributes will be replaced. (default: :obj:`True`) - norm_range (tuple, optional): Tuple specifying the range for normalization. - Each element of the tuple represents the lower and upper bounds for a - coordinate dimension. (default: :obj:`(0, 1)`) + interval ((float, float), optional): A tuple specifying the lower and + upper bound for normalization. (default: :obj:`(0.0, 1.0)`) """ - def __init__(self, norm=True, cat=True, norm_range=(0, 1)): + def __init__( + self, + norm: bool = True, + cat: bool = True, + interval: Tuple[float, float] = (0.0, 1.0), + ): self.norm = norm self.cat = cat - self.norm_range = norm_range + self.interval = interval def forward(self, data: Data) -> Data: (row, col), pos, pseudo = data.edge_index, data.pos, data.edge_attr - cart = torch.empty(row.size(0), pos.size(1), device=pos.device) - torch.sub(pos[row], pos[col], out=cart) # In-place subtraction - - max_value = torch.empty(pos.size(0), device=pos.device) - max_value.fill_(float('-inf')) - - for i in range(row.size(0)): - cart_abs = cart[i].abs() - max_value.index_copy_(0, col, - torch.maximum(max_value[col], cart_abs)) - - max_value = torch.max(max_value) + cart = pos[row] - pos[col] + cart = cart.view(-1, 1) if cart.dim() == 1 else cart if self.norm: - norm_range_min, norm_range_max = self.norm_range - norm_factor = 2 * max_value - norm_factor[norm_factor == 0] = 1 # Avoid division by zero - cart.div_(norm_factor).mul_(norm_range_max - norm_range_min).add_( - norm_range_min) # In-place normalization - else: - cart = cart + max_value = scatter(cart.abs(), col, 0, pos.size(0), reduce='max') + max_value = max_value.max(dim=-1, keepdim=True)[0] + + length = self.interval[1] - self.interval[0] + center = (self.interval[0] + self.interval[1]) / 2 + cart = length * cart / (2 * max_value[col]) + center if pseudo is not None and self.cat: pseudo = pseudo.view(-1, 1) if pseudo.dim() == 1 else pseudo From ae030f0db35eeb7c5ed87c7a8780c2880018262a Mon Sep 17 00:00:00 2001 From: rusty1s Date: Tue, 20 Jun 2023 13:06:35 +0000 Subject: [PATCH 4/4] update --- CHANGELOG.md | 1 + test/transforms/test_local_cartesian.py | 3 +-- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 9ba7d97a37c7..6b03a9f37899 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Added +- Added `interval` argument to `LocalCartesian` transformation ([#7533](https://github.com/pyg-team/pytorch_geometric/pull/7533)) - Enabled `LinkNeighborLoader` to return number of sampled nodes and edges per hop ([#7516](https://github.com/pyg-team/pytorch_geometric/pull/7516)) - Added the `HM` personalized fashion recommendation dataset ([#7515](https://github.com/pyg-team/pytorch_geometric/pull/7515)) - Added the `GraphMixer` model ([#7501](https://github.com/pyg-team/pytorch_geometric/pull/7501)) diff --git a/test/transforms/test_local_cartesian.py b/test/transforms/test_local_cartesian.py index a6952b0d5ba1..fdc752aaa01e 100644 --- a/test/transforms/test_local_cartesian.py +++ b/test/transforms/test_local_cartesian.py @@ -5,7 +5,7 @@ def test_local_cartesian(): - transform = LocalCartesian(interval=[0, 1]) + transform = LocalCartesian() assert str(transform) == 'LocalCartesian()' pos = torch.Tensor([[-1, 0], [0, 0], [2, 0]]) @@ -14,7 +14,6 @@ def test_local_cartesian(): data = Data(edge_index=edge_index, pos=pos) data = transform(data) - assert len(data) == 3 assert data.pos.tolist() == pos.tolist() assert data.edge_index.tolist() == edge_index.tolist()