Skip to content

Commit

Permalink
refactor - hetero models, add get_degree_histogram test, progress_bar
Browse files Browse the repository at this point in the history
  • Loading branch information
mszarma committed Jul 19, 2022
1 parent 81c4d01 commit fa7beaf
Show file tree
Hide file tree
Showing 7 changed files with 86 additions and 65 deletions.
33 changes: 11 additions & 22 deletions benchmark/inference/hetero_gat.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,38 +4,27 @@
from torch_geometric.nn import GATConv, to_hetero


class HETERO_GAT:
def __init__(self, hidden_channels, num_layers, output_channels,
num_heads) -> None:
self.model = None
self.hidden_channels = hidden_channels
self.output_channels = output_channels
self.num_layers = num_layers
self.num_heads = num_heads
class HeteroGAT(torch.nn.Module):
def __init__(self, metadata, hidden_channels, num_layers, output_channels,
num_heads):
super().__init__()
self.model = to_hetero(
GATForHetero(hidden_channels, num_layers, output_channels,
num_heads), metadata)
self.training = False

def create_hetero(self, metadata):
model = GAT_FOR_HETERO(self.hidden_channels, self.output_channels,
self.num_layers, self.num_heads)
self.model = to_hetero(model, metadata, aggr='sum')

def to(self, device):
self.model = self.model.to(device)
return self

@torch.inference_mode()
def inference(self, loader, device, progress_bar=False):
self.model.eval()
if progress_bar:
loader = tqdm(loader)
for batch in loader:
batch = batch.to(device)
batch_size = batch['paper'].batch_size
self.model(batch.x_dict,
batch.edge_index_dict)['paper'][:batch_size]
self.model(batch.x_dict, batch.edge_index_dict)


class GAT_FOR_HETERO(torch.nn.Module):
def __init__(self, hidden_channels, out_channels, num_layers, heads):
class GATForHetero(torch.nn.Module):
def __init__(self, hidden_channels, num_layers, out_channels, heads):
super().__init__()
self.convs = torch.nn.ModuleList()
self.convs.append(
Expand Down
30 changes: 10 additions & 20 deletions benchmark/inference/hetero_sage.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,36 +4,26 @@
from torch_geometric.nn import SAGEConv, to_hetero


class HETERO_SAGE:
def __init__(self, hidden_channels, num_layers, output_channels) -> None:
self.model = None
self.hidden_channels = hidden_channels
self.output_channels = output_channels
self.num_layers = num_layers
class HeteroGraphSAGE(torch.nn.Module):
def __init__(self, metadata, hidden_channels, num_layers, output_channels):
super().__init__()
self.model = to_hetero(
SAGEForHetero(hidden_channels, num_layers, output_channels),
metadata)
self.training = False

def create_hetero(self, metadata):
model = SAGE_FOR_HETERO(self.hidden_channels, self.output_channels,
self.num_layers)
self.model = to_hetero(model, metadata, aggr='sum')

def to(self, device):
self.model = self.model.to(device)
return self

@torch.inference_mode()
def inference(self, loader, device, progress_bar=False):
self.model.eval()
if progress_bar:
loader = tqdm(loader)
for batch in loader:
batch = batch.to(device)
batch_size = batch['paper'].batch_size
self.model(batch.x_dict,
batch.edge_index_dict)['paper'][:batch_size]
self.model(batch.x_dict, batch.edge_index_dict)


class SAGE_FOR_HETERO(torch.nn.Module):
def __init__(self, hidden_channels, out_channels, num_layers):
class SAGEForHetero(torch.nn.Module):
def __init__(self, hidden_channels, num_layers, out_channels):
super().__init__()
self.convs = torch.nn.ModuleList()
self.convs.append(SAGEConv((-1, -1), hidden_channels))
Expand Down
3 changes: 2 additions & 1 deletion benchmark/inference/inference_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,8 @@ def run(args: argparse.ArgumentParser) -> None:

if model_name == 'pna_conv':
if degree is None:
degree = PNAConv.get_degree(subgraph_loader)
degree = PNAConv.get_degree_histogram(
subgraph_loader)
print(f'Calculated degree for {dataset_name}.')
params['degree'] = degree

Expand Down
24 changes: 12 additions & 12 deletions benchmark/inference/utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import os.path as osp

from hetero_gat import HETERO_GAT
from hetero_sage import HETERO_SAGE
from hetero_gat import HeteroGAT
from hetero_sage import HeteroGraphSAGE
from ogb.nodeproppred import PygNodePropPredDataset

import torch_geometric.transforms as T
Expand All @@ -13,8 +13,8 @@
'gat': GAT,
'gcn': GCN,
'pna_conv': PNA,
'rgat': HETERO_GAT,
'rgcn': HETERO_SAGE,
'rgat': HeteroGAT,
'rgcn': HeteroGraphSAGE,
}


Expand All @@ -40,14 +40,14 @@ def get_model(name, params, metadata=None):
except KeyError:
print(f'Model {name} not supported!')

if name in ['rgat', 'rgcn']:
if name == 'rgat':
model = model_type(params['hidden_channels'], params['num_layers'],
params['output_channels'], params['num_heads'])
elif name == 'rgcn':
model = model_type(params['hidden_channels'], params['num_layers'],
params['output_channels'])
model.create_hetero(metadata)
if name == 'rgat':
model = model_type(metadata, params['hidden_channels'],
params['num_layers'], params['output_channels'],
params['num_heads'])

elif name == 'rgcn':
model = model_type(metadata, params['hidden_channels'],
params['num_layers'], params['output_channels'])

elif name == 'gat':
kwargs = {}
Expand Down
39 changes: 39 additions & 0 deletions test/nn/conv/test_pna_conv.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import torch
from torch_sparse import SparseTensor

from torch_geometric.data import Data
from torch_geometric.loader import DataLoader, NeighborLoader
from torch_geometric.nn import PNAConv
from torch_geometric.testing import is_full_test

Expand Down Expand Up @@ -32,3 +34,40 @@ def test_pna_conv():
t = '(Tensor, SparseTensor, OptTensor) -> Tensor'
jit = torch.jit.script(conv.jittable(t))
assert torch.allclose(jit(x, adj.t()), out, atol=1e-6)


def test_pna_conv_get_degree_histogram():
    """Check `PNAConv.get_degree_histogram` on two kinds of loaders.

    The expected tensors are in-degree histograms: entry `k` counts the
    nodes that appear exactly `k` times as an edge target
    (`edge_index[1]`). Every graph here has 5 nodes.
    """
    # Case 1: a single graph served through a `NeighborLoader`.
    # In-degrees per node are [3, 1, 2, 1, 0], i.e. one node each of
    # degree 0, 2 and 3 and two nodes of degree 1.
    edge_index = torch.tensor([[0, 0, 0, 1, 1, 2, 3], [1, 2, 3, 2, 0, 0, 0]])
    x = torch.randn(5, 16)
    data = Data(x=x, edge_index=edge_index)
    # NOTE(review): `num_neighbors=[-1]` with `batch_size=5` presumably
    # yields the whole graph in one batch -- the reference below equals
    # the full-graph histogram.
    loader = NeighborLoader(
        data,
        num_neighbors=[-1],
        input_nodes=None,
        batch_size=5,
        shuffle=False,
    )
    deg_hist = PNAConv.get_degree_histogram(loader)
    deg_hist_ref = torch.tensor([1, 2, 1, 1])
    assert torch.equal(deg_hist_ref, deg_hist)

    # Case 2: several graphs, one per batch; the histogram must be
    # accumulated over all batches.
    edge_index_1 = torch.tensor([[0, 0, 0, 1, 1, 2, 3], [1, 2, 3, 2, 0, 0, 0]])
    edge_index_2 = torch.tensor([[1, 1, 2, 2, 0, 3, 3], [2, 3, 3, 1, 1, 0, 2]])
    edge_index_3 = torch.tensor([[1, 3, 2, 0, 0, 4, 2], [2, 0, 4, 1, 1, 0, 3]])
    edge_index_4 = torch.tensor([[0, 1, 2, 4, 0, 1, 3], [2, 3, 3, 1, 1, 0, 2]])

    # The node features do not influence the histogram, so `x` from
    # case 1 is reused for every graph.
    data_1 = Data(x=x, edge_index=edge_index_1)  # deg_hist = [1, 2, 1, 1]
    data_2 = Data(x=x, edge_index=edge_index_2)  # deg_hist = [1, 1, 3]
    data_3 = Data(x=x, edge_index=edge_index_3)  # deg_hist = [0, 3, 2]
    data_4 = Data(x=x, edge_index=edge_index_4)  # deg_hist = [1, 1, 3]

    loader = DataLoader(
        [data_1, data_2, data_3, data_4],
        batch_size=1,
        shuffle=False,
    )
    deg_hist = PNAConv.get_degree_histogram(loader)
    # Element-wise sum of the four per-graph histograms above, each
    # zero-padded to length 4 (max in-degree 3).
    deg_hist_ref = torch.tensor([3, 7, 9, 1])
    assert torch.equal(deg_hist_ref, deg_hist)
12 changes: 6 additions & 6 deletions torch_geometric/nn/conv/pna_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,17 +172,17 @@ def __repr__(self):
f'edge_dim={self.edge_dim})')

@staticmethod
def get_degree(loader):
max_degree = -1
def get_degree_histogram(loader) -> Tensor:
max_degree = 0
for data in loader:
d = degree(data.edge_index[1], num_nodes=data.num_nodes,
dtype=torch.long)
max_degree = max(max_degree, int(d.max()))

# Compute the in-degree histogram tensor
deg = torch.zeros(max_degree + 1, dtype=torch.long)
deg_histogram = torch.zeros(max_degree + 1, dtype=torch.long)
for data in loader:
d = degree(data.edge_index[1], num_nodes=data.num_nodes,
dtype=torch.long)
deg += torch.bincount(d, minlength=deg.numel())
return deg
deg_histogram += torch.bincount(d, minlength=deg_histogram.numel())

return deg_histogram
10 changes: 6 additions & 4 deletions torch_geometric/nn/models/basic_gnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,8 +185,8 @@ def inference(self, loader: NeighborLoader,
assert not self.training
# assert not loader.shuffle # TODO (matthias) does not work :(
if progress_bar:
pbar = tqdm(total=len(self.convs))
pbar.set_description('Evaluating per layer')
pbar = tqdm(total=len(self.convs) * len(loader))
pbar.set_description('Evaluating')

x_all = loader.data.x.cpu()
loader.data.n_id = torch.arange(x_all.size(0))
Expand All @@ -199,6 +199,8 @@ def inference(self, loader: NeighborLoader,
x = self.convs[i](x, edge_index)[:batch.batch_size]
if i == self.num_layers - 1 and self.jk_mode is None:
xs.append(x.cpu())
if progress_bar:
pbar.update(1)
continue
if self.act is not None and self.act_first:
x = self.act(x)
Expand All @@ -209,9 +211,9 @@ def inference(self, loader: NeighborLoader,
if i == self.num_layers - 1 and hasattr(self, 'lin'):
x = self.lin(x)
xs.append(x.cpu())
if progress_bar:
pbar.update(1)
x_all = torch.cat(xs, dim=0)
if progress_bar:
pbar.update(1)
if progress_bar:
pbar.close()
del loader.data.n_id
Expand Down

0 comments on commit fa7beaf

Please sign in to comment.