Skip to content

Commit

Permalink
Merge branch 'pyg-team:master' into piotrc/disable_dynamic_shapes
Browse files Browse the repository at this point in the history
  • Loading branch information
piotrchmiel authored May 16, 2023
2 parents 5256481 + 239c4b5 commit a92cb50
Show file tree
Hide file tree
Showing 4 changed files with 141 additions and 37 deletions.
12 changes: 8 additions & 4 deletions torch_geometric/nn/conv/rgcn_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
torch_sparse,
)
from torch_geometric.utils import index_sort, one_hot, scatter, spmm
from torch_geometric.utils.hetero import segmatmul_heuristic
from torch_geometric.utils.sparse import index2ptr


Expand Down Expand Up @@ -126,7 +127,7 @@ def __init__(
self.num_bases = num_bases
self.num_blocks = num_blocks
self.is_sorted = is_sorted

self.use_segmm: int = -1
if isinstance(in_channels, int):
in_channels = (in_channels, in_channels)
self.in_channels_l = in_channels[0]
Expand Down Expand Up @@ -201,7 +202,6 @@ def forward(self, x: Union[OptTensor, Tuple[OptTensor, Tensor]],
x_r = x[1]

size = (x_l.size(0), x_r.size(0))

if isinstance(edge_index, SparseTensor):
edge_type = edge_index.storage.value()
assert edge_type is not None
Expand Down Expand Up @@ -230,14 +230,18 @@ def forward(self, x: Union[OptTensor, Tuple[OptTensor, Tensor]],

else: # No regularization/Basis-decomposition ========================
if (torch_geometric.typing.WITH_PYG_LIB and self.num_bases is None
and x_l.is_floating_point()
and isinstance(edge_index, Tensor)):
and x_l.is_floating_point() and isinstance(
edge_index, Tensor)) and (self.use_segmm == -1
or bool(self.use_segmm)):
if not self.is_sorted:
if (edge_type[1:] < edge_type[:-1]).any():
edge_type, perm = index_sort(
edge_type, max_value=self.num_relations)
edge_index = edge_index[:, perm]
edge_type_ptr = index2ptr(edge_type, self.num_relations)
if self.use_segmm == -1:
self.use_segmm = segmatmul_heuristic(
x_l, edge_type_ptr, self.weight)
out = self.propagate(edge_index, x=x_l,
edge_type_ptr=edge_type_ptr, size=size)
else:
Expand Down
59 changes: 27 additions & 32 deletions torch_geometric/nn/dense/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from torch_geometric.nn import inits
from torch_geometric.typing import pyg_lib
from torch_geometric.utils import index_sort
from torch_geometric.utils.hetero import segmatmul_heuristic
from torch_geometric.utils.sparse import index2ptr


Expand Down Expand Up @@ -215,48 +216,36 @@ def __init__(self, in_channels: int, out_channels: int, num_types: int,
self.num_types = num_types
self.is_sorted = is_sorted
self.kwargs = kwargs

if torch_geometric.typing.WITH_PYG_LIB:
self.lins = None
if self.in_channels == -1:
self.weight = nn.parameter.UninitializedParameter()
self._hook = self.register_forward_pre_hook(
self.initialize_parameters)
else:
self.weight = torch.nn.Parameter(
torch.Tensor(num_types, in_channels, out_channels))
if kwargs.get('bias', True):
self.bias = Parameter(torch.Tensor(num_types, out_channels))
else:
self.register_parameter('bias', None)
self.use_segmm: int = -1
if self.in_channels == -1:
self.weight = nn.parameter.UninitializedParameter()
self._hook = self.register_forward_pre_hook(
self.initialize_parameters)
else:
self.weight = torch.nn.Parameter(
torch.Tensor(num_types, in_channels, out_channels))
if kwargs.get('bias', True):
self.bias = Parameter(torch.Tensor(num_types, out_channels))
else:
self.lins = torch.nn.ModuleList([
Linear(in_channels, out_channels, **kwargs)
for _ in range(num_types)
])
self.register_parameter('weight', None)
self.register_parameter('bias', None)

self.reset_parameters()

def reset_parameters(self):
r"""Resets all learnable parameters of the module."""
if torch_geometric.typing.WITH_PYG_LIB:
reset_weight_(self.weight, self.in_channels,
self.kwargs.get('weight_initializer', None))
reset_weight_(self.bias, self.in_channels,
self.kwargs.get('bias_initializer', None))
else:
for lin in self.lins:
lin.reset_parameters()
reset_weight_(self.weight, self.in_channels,
self.kwargs.get('weight_initializer', None))
reset_weight_(self.bias, self.in_channels,
self.kwargs.get('bias_initializer', None))

def forward(self, x: Tensor, type_vec: Tensor) -> Tensor:
r"""
Args:
x (torch.Tensor): The input features.
type_vec (torch.Tensor): A vector that maps each entry to a type.
"""
if torch_geometric.typing.WITH_PYG_LIB:

if torch_geometric.typing.WITH_PYG_LIB and (self.use_segmm == -1
or bool(self.use_segmm)):
assert self.weight is not None

perm: Optional[Tensor] = None
Expand All @@ -266,6 +255,9 @@ def forward(self, x: Tensor, type_vec: Tensor) -> Tensor:
x = x[perm]

type_vec_ptr = index2ptr(type_vec, self.num_types)
if self.use_segmm == -1:
self.use_segmm = segmatmul_heuristic(x, type_vec_ptr,
self.weight)
out = pyg_lib.ops.segment_matmul(x, type_vec_ptr, self.weight)
if self.bias is not None:
out += self.bias[type_vec]
Expand All @@ -275,11 +267,14 @@ def forward(self, x: Tensor, type_vec: Tensor) -> Tensor:
out_unsorted[perm] = out
out = out_unsorted
else:
assert self.lins is not None
out = x.new_empty(x.size(0), self.out_channels)
for i, lin in enumerate(self.lins):
for i in range(self.num_types):
mask = type_vec == i
out[mask] = lin(x[mask])
if mask.numel() == 0:
continue
out[mask] = F.linear(x[mask], self.weight[i].T)
if self.bias is not None:
out += self.bias[type_vec]
return out

@torch.no_grad()
Expand Down
4 changes: 3 additions & 1 deletion torch_geometric/nn/to_hetero_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import Dict, List, Optional, Union

import torch
import torch.nn.functional as F
from torch import Tensor

import torch_geometric
Expand Down Expand Up @@ -54,7 +55,8 @@ def dict_forward(

if not torch_geometric.typing.WITH_PYG_LIB:
return {
key: self.hetero_module.lins[i](x_dict[key])
key: F.linear(x_dict[key], self.hetero_module.weight[i].T) +
self.hetero_module.bias[i]
for i, key in enumerate(self.types)
}

Expand Down
103 changes: 103 additions & 0 deletions torch_geometric/utils/hetero.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,109 @@
from torch_geometric.utils.num_nodes import maybe_num_nodes_dict


def learn_sklearn_heuristic():
    r"""Benchmarks the fused heterogeneous linear layer against a per-type
    loop on the current CUDA device and fits a linear SVM that predicts which
    of the two is faster.

    For every configuration ``(num_types, num_nodes_per_type, n_feats,
    out_feats)`` of a fixed grid (configurations with
    ``n_feats < out_feats`` are skipped), both variants are run for 10
    warm-up and 50 timed iterations. A configuration is labelled ``1`` if the
    fused variant is at least as fast as the loop and ``0`` otherwise. A
    ``StandardScaler`` + ``LinearSVC`` pipeline is then fitted on these
    labels and its parameters are printed so that they can be hard-coded
    into :func:`segmatmul_heuristic`.

    Requires a CUDA device, :obj:`numpy` and :obj:`scikit-learn`. The
    function only prints its results and returns :obj:`None`.

    Raises:
        RuntimeError: If not a single configuration could be benchmarked.
    """
    import os
    import time
    from itertools import product

    # Imported up-front so that a missing package is reported before the
    # (long-running) benchmark starts rather than after it has finished.
    import numpy as np
    from sklearn.pipeline import make_pipeline
    from sklearn.preprocessing import StandardScaler
    from sklearn.svm import LinearSVC

    from torch_geometric.nn.dense import HeteroLinear, Linear

    # NOTE(review): presumably meant to disable TF32 so that timings reflect
    # full FP32 math - confirm that it still takes effect when set this late.
    os.environ['NVIDIA_TF32_OVERRIDE'] = '0'

    num_warmups, num_timed = 10, 50

    def avg_time(fn):
        # Average wall-clock time of `fn` over the timed iterations. CUDA
        # kernels are launched asynchronously, so the device has to be
        # synchronized before the clock is read on either side.
        for _ in range(num_warmups):
            fn()
        torch.cuda.synchronize()
        since = time.perf_counter()
        for _ in range(num_timed):
            fn()
        torch.cuda.synchronize()
        return (time.perf_counter() - since) / num_timed

    grid = product(
        [10**2, 10**3, 10**4, 10**5],  # num_nodes_per_type
        [2, 4, 8, 16, 32, 64, 128, 256],  # out_feats
        [4, 8, 16, 32, 64, 128, 256, 512],  # n_feats
        [4, 8, 16, 32, 64, 128, 256, 512],  # num_types
    )

    fused_times = {}
    loop_times = {}
    try:
        for num_nodes_per_type, out_feats, n_feats, num_types in grid:
            if n_feats < out_feats:
                continue
            print("benchmarking", num_types, "types w/",
                  num_nodes_per_type, "nodes per type and",
                  n_feats, "input features and", out_feats,
                  "outuput feats")
            try:
                x = torch.randn(num_types * num_nodes_per_type, n_feats,
                                device='cuda')
                # Nodes are laid out type by type, i.e. already sorted.
                node_type = torch.arange(
                    num_types,
                    device='cuda').repeat_interleave(num_nodes_per_type)
                lin = Linear(n_feats, out_feats).cuda()
                # The fourth positional argument is presumably `is_sorted`.
                heterolin = HeteroLinear(n_feats, out_feats, num_types,
                                         True).cuda()
                # `HeteroLinear` caches the heuristic's verdict in
                # `use_segmm` on its first call and may switch to its own
                # loop fallback afterwards. Pin it to the fused path so that
                # the "fused" timing really measures the fused kernel.
                heterolin.use_segmm = 1

                def fused():
                    heterolin(x=x, type_vec=node_type)

                def looped():
                    o = x.new_empty(x.size(0), out_feats)
                    for j in range(num_types):
                        mask = j == node_type
                        o[mask] = lin(x[mask])

                key = (num_types, num_nodes_per_type, n_feats, out_feats)
                fused_time = avg_time(fused)
                print("Avg time for fuse based=", fused_time)
                loop_time = avg_time(looped)
                print("Avg time for for-loop=", loop_time)
            except Exception as e:
                # Best effort: a configuration that fails (e.g., because it
                # does not fit into GPU memory) is left out of the data set.
                print("skipping configuration:", e)
                continue
            # Only record configurations for which both timings exist.
            fused_times[key] = fused_time
            loop_times[key] = loop_time
    except KeyboardInterrupt:
        # Allow cutting the sweep short and fitting on what was collected.
        print("benchmark interrupted, fitting on collected timings")

    if not loop_times:
        raise RuntimeError("No configuration could be benchmarked "
                           "(is a CUDA device available?)")

    keys = list(loop_times)
    X = np.array(keys, dtype=float)
    y = np.array([float(fused_times[k] <= loop_times[k]) for k in keys])

    scaler = StandardScaler()
    svm = LinearSVC()
    clf = make_pipeline(scaler, svm)
    clf.fit(X, y)

    # The fitted classifier predicts label 1 (fused is faster) for a feature
    # vector `v` iff `coef . ((v - mean) / scale) + intercept > 0`.
    print("scaler mean=", scaler.mean_)
    print("scaler scale=", scaler.scale_)
    print("svm weights=", svm.coef_)
    print("svm bias=", svm.intercept_)
    # results on A100:
    # scaler mean=
    #   [  125.11603189 12133.21523472   163.81222321    32.43755536]
    # scaler scale=
    #   [  163.34480422 27572.94543809   177.6426489     56.82103934]
    # svm weights=
    #   [[ 2.43877659e+00  1.67583047e+00 -5.20527282e-04  3.43925501e-01]]
    # svm bias=
    #   [1.20236999]


def segmatmul_heuristic(inputs: Tensor, type_ptr, weight: Tensor):
    r"""Predicts whether a fused segment matrix multiplication is expected to
    be at least as fast as a per-type loop of dense matrix multiplications.

    The decision is made by a linear SVM on four standardized features: the
    number of types, the largest number of rows of a single type, the number
    of input features and the number of output features. Its parameters were
    learned with :func:`learn_sklearn_heuristic` on an A100.

    Args:
        inputs (torch.Tensor): The input features; only :obj:`inputs.size(1)`
            is read.
        type_ptr: Compressed type boundaries, such that
            :obj:`type_ptr[i + 1] - type_ptr[i]` is the number of rows of
            type :obj:`i`. Must support :func:`len`, slicing and subtraction
            (presumably a one-dimensional :class:`torch.Tensor`).
        weight (torch.Tensor): The weights; only :obj:`weight.size(-1)` is
            read.

    Returns:
        int: :obj:`1` if the fused segment matmul is predicted to be faster,
        :obj:`0` otherwise (including the case of zero types).
    """
    num_types = len(type_ptr) - 1
    if num_types <= 0:
        # Nothing to multiply; `.max()` below would fail on an empty tensor.
        return 0

    # Convert to plain Python ints up-front, so that no (possibly
    # device-resident) 0-dim tensor ends up inside the list passed to
    # `torch.tensor` below.
    max_num_nodes_per_type = int((type_ptr[1:] - type_ptr[:-1]).max())
    in_feat = int(inputs.size(1))
    out_feat = int(weight.size(-1))

    # this heuristic was learned with learn_sklearn_heuristic on an A100
    x = torch.tensor(
        [num_types, max_num_nodes_per_type, in_feat, out_feat],
        dtype=torch.float)
    scale_mean = torch.tensor(
        [125.11603189, 12133.21523472, 163.81222321, 32.43755536])
    scale_scale = torch.tensor(
        [163.34480422, 27572.94543809, 177.6426489, 56.82103934])
    svm_weights = torch.tensor(
        [2.43877659e+00, 1.67583047e+00, -5.20527282e-04, 3.43925501e-01])
    bias = 1.20236999

    x = (x - scale_mean) / scale_scale
    # A scikit-learn linear classifier predicts the positive class iff
    # `coef . x + intercept > 0`, i.e., the intercept is *added* to the
    # score rather than used as the threshold to compare against.
    return int(float(x.dot(svm_weights)) + bias > 0)


def group_hetero_graph(edge_index_dict, num_nodes_dict=None):
num_nodes_dict = maybe_num_nodes_dict(edge_index_dict, num_nodes_dict)

Expand Down

0 comments on commit a92cb50

Please sign in to comment.