Add immutable transforms #7429

Merged · 3 commits · May 24, 2023
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -45,6 +45,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Changed

- All transforms are now immutable, i.e., they perform a shallow copy of the data and therefore no longer modify the data in-place ([#7429](https://github.com/pyg-team/pytorch_geometric/pull/7429))
- Set `output_size` in the `repeat_interleave` operation in `QuantileAggregation` ([#7426](https://github.com/pyg-team/pytorch_geometric/pull/7426))
- Fixed gradient computation of edge weights in `utils.spmm` ([#7428](https://github.com/pyg-team/pytorch_geometric/pull/7428))
- Re-factored `ClusterLoader` to integrate `pyg-lib` METIS routine ([#7416](https://github.com/pyg-team/pytorch_geometric/pull/7416))
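A minimal sketch of what this immutability guarantee means in user code. The toy graph below is a hypothetical example, and the asserted behavior assumes the post-#7429 `BaseTransform`:

```python
import torch
import torch_geometric.transforms as T
from torch_geometric.data import Data

# Hypothetical toy graph used purely for illustration.
data = Data(
    x=torch.tensor([[1.0, 3.0], [2.0, 2.0]]),
    edge_index=torch.tensor([[0, 1], [1, 0]]),
)

out = T.NormalizeFeatures()(data)

# The transform operates on a shallow copy of `data`, so it returns a new
# `Data` object and leaves the original feature matrix untouched.
assert out is not data
assert torch.equal(data.x, torch.tensor([[1.0, 3.0], [2.0, 2.0]]))
```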
2 changes: 2 additions & 0 deletions examples/colors_topk_pool.py
@@ -1,3 +1,4 @@
import copy
import os.path as osp

import torch
@@ -14,6 +15,7 @@

class HandleNodeAttention:
def __call__(self, data):
data = copy.copy(data)
data.attn = torch.softmax(data.x[:, 0], dim=0)
data.x = data.x[:, 1:]
return data
14 changes: 6 additions & 8 deletions examples/proteins_diff_pool.py
@@ -11,16 +11,14 @@

max_nodes = 150


class MyFilter:
def __call__(self, data):
return data.num_nodes <= max_nodes


path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data',
'PROTEINS_dense')
dataset = TUDataset(path, name='PROTEINS', transform=T.ToDense(max_nodes),
pre_filter=MyFilter())
dataset = TUDataset(
path,
name='PROTEINS',
transform=T.ToDense(max_nodes),
pre_filter=lambda data: data.num_nodes <= max_nodes,
)
dataset = dataset.shuffle()
n = (len(dataset) + 9) // 10
test_dataset = dataset[:n]
6 changes: 4 additions & 2 deletions examples/qm9_nn_conv.py
@@ -1,3 +1,4 @@
import copy
import os.path as osp

import torch
@@ -16,13 +17,14 @@

class MyTransform:
def __call__(self, data):
# Specify target.
data.y = data.y[:, target]
data = copy.copy(data)
data.y = data.y[:, target] # Specify target.
return data


class Complete:
def __call__(self, data):
data = copy.copy(data)
device = data.edge_index.device

row = torch.arange(data.num_nodes, dtype=torch.long, device=device)
2 changes: 2 additions & 0 deletions examples/triangles_sag_pool.py
@@ -1,3 +1,4 @@
import copy
import os.path as osp

import torch
@@ -15,6 +16,7 @@

class HandleNodeAttention:
def __call__(self, data):
data = copy.copy(data)
data.attn = torch.softmax(data.x, dim=0).flatten()
data.x = None
return data
1 change: 1 addition & 0 deletions torch_geometric/data/in_memory_dataset.py
@@ -85,6 +85,7 @@ def len(self) -> int:
return 0

def get(self, idx: int) -> Data:
# TODO (matthias) Avoid unnecessary copy here.
if self.len() == 1:
return copy.copy(self._data)

4 changes: 2 additions & 2 deletions torch_geometric/transforms/add_metapaths.py
@@ -129,7 +129,7 @@ def __init__(
self.max_sample = max_sample
self.weighted = weighted

def __call__(self, data: HeteroData) -> HeteroData:
def forward(self, data: HeteroData) -> HeteroData:
edge_types = data.edge_types # save original edge types
data.metapath_dict = {}

@@ -244,7 +244,7 @@ def __init__(
assert len(walks_per_node) == len(metapaths)
self.walks_per_node = walks_per_node

def __call__(self, data: HeteroData) -> HeteroData:
def forward(self, data: HeteroData) -> HeteroData:
edge_types = data.edge_types # save original edge types
data.metapath_dict = {}

4 changes: 2 additions & 2 deletions torch_geometric/transforms/add_positional_encoding.py
@@ -64,7 +64,7 @@ def __init__(
self.is_undirected = is_undirected
self.kwargs = kwargs

def __call__(self, data: Data) -> Data:
def forward(self, data: Data) -> Data:
from scipy.sparse.linalg import eigs, eigsh
eig_fn = eigs if not self.is_undirected else eigsh

@@ -117,7 +117,7 @@ def __init__(
self.walk_length = walk_length
self.attr_name = attr_name

def __call__(self, data: Data) -> Data:
def forward(self, data: Data) -> Data:
row, col = data.edge_index
N = data.num_nodes

2 changes: 1 addition & 1 deletion torch_geometric/transforms/add_remaining_self_loops.py
@@ -32,7 +32,7 @@ def __init__(self, attr: Optional[str] = 'edge_weight',
self.attr = attr
self.fill_value = fill_value

def __call__(
def forward(
self,
data: Union[Data, HeteroData],
) -> Union[Data, HeteroData]:
2 changes: 1 addition & 1 deletion torch_geometric/transforms/add_self_loops.py
@@ -32,7 +32,7 @@ def __init__(self, attr: Optional[str] = 'edge_weight',
self.attr = attr
self.fill_value = fill_value

def __call__(
def forward(
self,
data: Union[Data, HeteroData],
) -> Union[Data, HeteroData]:
9 changes: 7 additions & 2 deletions torch_geometric/transforms/base_transform.py
@@ -1,4 +1,5 @@
from abc import ABC
import copy
from abc import ABC, abstractmethod
from typing import Any


@@ -27,7 +28,11 @@ class BaseTransform(ABC):
data = transform(data) # Explicitly transform data.
"""
def __call__(self, data: Any) -> Any:
raise NotImplementedError
return self.forward(copy.copy(data))

@abstractmethod
def forward(self, data: Any) -> Any:
pass

def __repr__(self) -> str:
return f'{self.__class__.__name__}()'
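With this change, a custom transform subclasses `BaseTransform` and implements `forward` instead of `__call__`; the base class shallow-copies the input before dispatching to `forward`. A minimal sketch — the `AppendOnes` transform below is a hypothetical example, not part of the library:

```python
import torch
from torch_geometric.data import Data
from torch_geometric.transforms import BaseTransform


class AppendOnes(BaseTransform):
    # Hypothetical transform that appends a constant feature column.
    def forward(self, data: Data) -> Data:
        ones = torch.ones(data.num_nodes, 1)
        # Re-assigning an attribute only touches the shallow copy created by
        # `BaseTransform.__call__`, so the caller's `data` stays unchanged.
        data.x = ones if data.x is None else torch.cat([data.x, ones], dim=-1)
        return data


data = Data(x=torch.rand(4, 8), edge_index=torch.empty(2, 0, dtype=torch.long))
out = AppendOnes()(data)
assert out.x.size(-1) == 9 and data.x.size(-1) == 8
```

Note that only the data container is copied: in-place tensor operations inside `forward` (e.g. `data.x.add_(1)`) would still mutate storage shared with the caller.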
2 changes: 1 addition & 1 deletion torch_geometric/transforms/cartesian.py
@@ -32,7 +32,7 @@ def __init__(
self.max = max_value
self.cat = cat

def __call__(self, data: Data) -> Data:
def forward(self, data: Data) -> Data:
(row, col), pos, pseudo = data.edge_index, data.pos, data.edge_attr

cart = pos[row] - pos[col]
2 changes: 1 addition & 1 deletion torch_geometric/transforms/center.py
@@ -9,7 +9,7 @@
class Center(BaseTransform):
r"""Centers node positions :obj:`data.pos` around the origin
(functional name: :obj:`center`)."""
def __call__(
def forward(
self,
data: Union[Data, HeteroData],
) -> Union[Data, HeteroData]:
2 changes: 1 addition & 1 deletion torch_geometric/transforms/compose.py
@@ -13,7 +13,7 @@ class Compose(BaseTransform):
def __init__(self, transforms: List[Callable]):
self.transforms = transforms

def __call__(
def forward(
self,
data: Union[Data, HeteroData],
) -> Union[Data, HeteroData]:
2 changes: 1 addition & 1 deletion torch_geometric/transforms/constant.py
@@ -34,7 +34,7 @@ def __init__(
self.cat = cat
self.node_types = node_types

def __call__(
def forward(
self,
data: Union[Data, HeteroData],
) -> Union[Data, HeteroData]:
2 changes: 1 addition & 1 deletion torch_geometric/transforms/delaunay.py
@@ -10,7 +10,7 @@
class Delaunay(BaseTransform):
r"""Computes the delaunay triangulation of a set of points
(functional name: :obj:`delaunay`)."""
def __call__(self, data: Data) -> Data:
def forward(self, data: Data) -> Data:
if data.pos.size(0) < 2:
data.edge_index = torch.tensor([], dtype=torch.long,
device=data.pos.device).view(2, 0)
2 changes: 1 addition & 1 deletion torch_geometric/transforms/distance.py
@@ -27,7 +27,7 @@ def __init__(self, norm: bool = True, max_value: Optional[float] = None,
self.max = max_value
self.cat = cat

def __call__(self, data: Data) -> Data:
def forward(self, data: Data) -> Data:
(row, col), pos, pseudo = data.edge_index, data.pos, data.edge_attr

dist = torch.norm(pos[col] - pos[row], p=2, dim=-1).view(-1, 1)
2 changes: 1 addition & 1 deletion torch_geometric/transforms/face_to_edge.py
@@ -18,7 +18,7 @@ class FaceToEdge(BaseTransform):
def __init__(self, remove_faces: bool = True):
self.remove_faces = remove_faces

def __call__(self, data: Data) -> Data:
def forward(self, data: Data) -> Data:
if hasattr(data, 'face'):
face = data.face
edge_index = torch.cat([face[:2], face[1:], face[::2]], dim=1)
2 changes: 1 addition & 1 deletion torch_geometric/transforms/feature_propagation.py
@@ -40,7 +40,7 @@ def __init__(self, missing_mask: Tensor, num_iterations: int = 40):
self.missing_mask = missing_mask
self.num_iterations = num_iterations

def __call__(self, data: Data) -> Data:
def forward(self, data: Data) -> Data:
assert 'edge_index' in data or 'adj_t' in data
assert data.x.size() == self.missing_mask.size()

2 changes: 1 addition & 1 deletion torch_geometric/transforms/fixed_points.py
@@ -37,7 +37,7 @@ def __init__(
self.replace = replace
self.allow_duplicates = allow_duplicates

def __call__(self, data: Data) -> Data:
def forward(self, data: Data) -> Data:
num_nodes = data.num_nodes

if self.replace:
2 changes: 1 addition & 1 deletion torch_geometric/transforms/gcn_norm.py
@@ -19,7 +19,7 @@ class GCNNorm(BaseTransform):
def __init__(self, add_self_loops: bool = True):
self.add_self_loops = add_self_loops

def __call__(self, data: Data) -> Data:
def forward(self, data: Data) -> Data:
gcn_norm = torch_geometric.nn.conv.gcn_conv.gcn_norm
assert 'edge_index' in data or 'adj_t' in data

2 changes: 1 addition & 1 deletion torch_geometric/transforms/gdc.py
@@ -97,7 +97,7 @@ def __init__(
assert exact or self_loop_weight == 1

@torch.no_grad()
def __call__(self, data: Data) -> Data:
def forward(self, data: Data) -> Data:
N = data.num_nodes
edge_index = data.edge_index
if data.edge_attr is None:
2 changes: 1 addition & 1 deletion torch_geometric/transforms/generate_mesh_normals.py
@@ -11,7 +11,7 @@
class GenerateMeshNormals(BaseTransform):
r"""Generate normal vectors for each mesh node based on neighboring
faces (functional name: :obj:`generate_mesh_normals`)."""
def __call__(self, data: Data) -> Data:
def forward(self, data: Data) -> Data:
assert 'face' in data
pos, face = data.pos, data.face

2 changes: 1 addition & 1 deletion torch_geometric/transforms/grid_sampling.py
@@ -36,7 +36,7 @@ def __init__(self, size: Union[float, List[float], Tensor],
self.start = start
self.end = end

def __call__(self, data: Data) -> Data:
def forward(self, data: Data) -> Data:
num_nodes = data.num_nodes

batch = data.get('batch', None)
2 changes: 1 addition & 1 deletion torch_geometric/transforms/knn_graph.py
@@ -45,7 +45,7 @@ def __init__(
self.cosine = cosine
self.num_workers = num_workers

def __call__(self, data: Data) -> Data:
def forward(self, data: Data) -> Data:
data.edge_attr = None
batch = data.batch if 'batch' in data else None

2 changes: 1 addition & 1 deletion torch_geometric/transforms/laplacian_lambda_max.py
@@ -40,7 +40,7 @@ def __init__(
self.normalization = normalization
self.is_undirected = is_undirected

def __call__(self, data: Data) -> Data:
def forward(self, data: Data) -> Data:
edge_weight = data.edge_attr
if edge_weight is not None and edge_weight.numel() != data.num_edges:
edge_weight = None
@@ -28,7 +28,7 @@ def __init__(self, num_components: int = 1, connection: str = 'weak'):
self.num_components = num_components
self.connection = connection

def __call__(self, data: Data) -> Data:
def forward(self, data: Data) -> Data:
import numpy as np
import scipy.sparse as sp

2 changes: 1 addition & 1 deletion torch_geometric/transforms/line_graph.py
@@ -34,7 +34,7 @@ class LineGraph(BaseTransform):
def __init__(self, force_directed: bool = False):
self.force_directed = force_directed

def __call__(self, data: Data) -> Data:
def forward(self, data: Data) -> Data:
N = data.num_nodes
edge_index, edge_attr = data.edge_index, data.edge_attr
edge_index, edge_attr = coalesce(edge_index, edge_attr, num_nodes=N)
4 changes: 2 additions & 2 deletions torch_geometric/transforms/linear_transformation.py
@@ -26,10 +26,10 @@ def __init__(self, matrix: Tensor):
f'Transformation matrix should be square (got {matrix.size()})')

# Store the matrix as its transpose.
# We do this to enable post-multiplication in `__call__`.
# We do this to enable post-multiplication in `forward`.
self.matrix = matrix.t()

def __call__(
def forward(
self,
data: Union[Data, HeteroData],
) -> Union[Data, HeteroData]:
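The post-multiplication trick referenced in the comment above works because node positions are stored as row vectors: applying a matrix to each position as a column vector is the same as right-multiplying the stacked row vectors by the matrix's transpose. A quick check with toy values (not taken from this diff):

```python
import torch

matrix = torch.tensor([[0.0, -1.0], [1.0, 0.0]])  # 90-degree rotation
pos = torch.rand(5, 2)                            # positions stored as row vectors

# Applying `matrix` to each position as a column vector ...
expected = (matrix @ pos.t()).t()
# ... equals post-multiplying the row vectors by the stored transpose.
assert torch.allclose(pos @ matrix.t(), expected)
```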
2 changes: 1 addition & 1 deletion torch_geometric/transforms/local_cartesian.py
@@ -23,7 +23,7 @@ def __init__(self, norm: bool = True, cat: bool = True):
self.norm = norm
self.cat = cat

def __call__(self, data: Data) -> Data:
def forward(self, data: Data) -> Data:
(row, col), pos, pseudo = data.edge_index, data.pos, data.edge_attr

cart = pos[row] - pos[col]
2 changes: 1 addition & 1 deletion torch_geometric/transforms/local_degree_profile.py
@@ -24,7 +24,7 @@ def __init__(self):
from torch_geometric.nn.aggr.fused import FusedAggregation
self.aggr = FusedAggregation(['min', 'max', 'mean', 'std'])

def __call__(self, data: Data) -> Data:
def forward(self, data: Data) -> Data:
row, col = data.edge_index
N = data.num_nodes

4 changes: 2 additions & 2 deletions torch_geometric/transforms/mask.py
@@ -52,7 +52,7 @@ def __init__(
self.sizes = sizes
self.replace = replace

def __call__(
def forward(
self,
data: Union[Data, HeteroData],
) -> Union[Data, HeteroData]:
@@ -106,7 +106,7 @@ def __init__(
self.attrs = [attrs] if isinstance(attrs, str) else attrs
self.replace = replace

def __call__(
def forward(
self,
data: Union[Data, HeteroData],
) -> Union[Data, HeteroData]:
3 changes: 1 addition & 2 deletions torch_geometric/transforms/node_property_split.py
@@ -81,8 +81,7 @@ def __init__(
self.ratios = ratios
self.ascending = ascending

def __call__(self, data: Data) -> Data:

def forward(self, data: Data) -> Data:
G = to_networkx(data, to_undirected=True, remove_self_loops=True)
property_values = self.compute_fn(G, self.ascending)
mask_dict = self._mask_nodes_by_property(property_values, self.ratios)
2 changes: 1 addition & 1 deletion torch_geometric/transforms/normalize_features.py
@@ -17,7 +17,7 @@ class NormalizeFeatures(BaseTransform):
def __init__(self, attrs: List[str] = ["x"]):
self.attrs = attrs

def __call__(
def forward(
self,
data: Union[Data, HeteroData],
) -> Union[Data, HeteroData]:
2 changes: 1 addition & 1 deletion torch_geometric/transforms/normalize_rotation.py
@@ -24,7 +24,7 @@ def __init__(self, max_points: int = -1, sort: bool = False):
self.max_points = max_points
self.sort = sort

def __call__(self, data: Data) -> Data:
def forward(self, data: Data) -> Data:
pos = data.pos

if self.max_points > 0 and pos.size(0) > self.max_points:
2 changes: 1 addition & 1 deletion torch_geometric/transforms/normalize_scale.py
@@ -11,7 +11,7 @@ class NormalizeScale(BaseTransform):
def __init__(self):
self.center = Center()

def __call__(self, data: Data) -> Data:
def forward(self, data: Data) -> Data:
data = self.center(data)

scale = (1 / data.pos.abs().max()) * 0.999999