Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Benchmark] Adding inference benchmark suite #4915

Merged
merged 16 commits into from
Jul 21, 2022
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 39 additions & 0 deletions benchmark/inference/edgeconv.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import torch
import torch.nn.functional as F
from torch.nn import Linear as Lin
from torch.nn import ReLU
from torch.nn import Sequential as Seq
from tqdm import tqdm

from torch_geometric.nn import EdgeConv


class EdgeConvNet(torch.nn.Module):
    """Stack of ``EdgeConv`` layers, each wrapping a two-layer MLP.

    A first and a last layer are always created, so the network has at
    least two layers; ``num_layers - 2`` hidden layers are inserted in
    between (none when ``num_layers <= 2``).

    Args:
        input_channels: Feature width of the input nodes.
        hidden_channels: Width used inside every MLP and between layers.
        out_channels: Feature width produced by the last layer.
        num_layers: Requested number of ``EdgeConv`` layers.
    """
    def __init__(self, input_channels, hidden_channels, out_channels,
                 num_layers):
        super().__init__()

        def make_mlp(in_width, out_width):
            # The first linear layer takes ``2 * in_width`` inputs, matching
            # the per-layer MLP shapes this model has always used.
            return Seq(Lin(2 * in_width, hidden_channels), ReLU(),
                       Lin(hidden_channels, out_width))

        self.convs = torch.nn.ModuleList()
        self.convs.append(EdgeConv(make_mlp(input_channels, hidden_channels)))
        for _ in range(num_layers - 2):
            # Build a fresh MLP per hidden layer. Reusing a single module
            # instance here would make all hidden layers share one set of
            # weights.
            self.convs.append(
                EdgeConv(make_mlp(hidden_channels, hidden_channels)))
        self.convs.append(EdgeConv(make_mlp(hidden_channels, out_channels)))

    def forward(self, x, edge_index):
        """Apply all layers, with an in-place ReLU after all but the last."""
        last = len(self.convs) - 1
        for i, conv in enumerate(self.convs):
            x = conv(x, edge_index)
            if i < last:
                x = x.relu_()
        return x

    @torch.no_grad()
    def inference(self, subgraph_loader, device):
        """Run the model over every batch of ``subgraph_loader``.

        Each batch is moved to ``device`` and the first ``batch.batch_size``
        output rows are sliced out. The result is discarded and the method
        returns ``None``.
        """
        for batch in tqdm(subgraph_loader):
            batch = batch.to(device)
            batch_size = batch.batch_size
            self(batch.x, batch.edge_index)[:batch_size]
50 changes: 50 additions & 0 deletions benchmark/inference/gat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import torch
import torch.nn.functional as F
from torch.nn import Linear
from tqdm import tqdm

from torch_geometric.nn import GATConv


class GATBlock(torch.nn.Module):
    """A single ``GATConv`` followed by ELU (omitted on the last block).

    Args:
        in_channels: Input feature width.
        out_channels: Per-head output width handed to ``GATConv``.
        heads: Number of attention heads handed to ``GATConv``.
        last_layer: When true, ``forward`` returns the raw convolution
            output without the ELU activation.
        **conv_kwargs: Extra keyword arguments forwarded to ``GATConv``.
    """
    def __init__(self, in_channels, out_channels, heads, last_layer=False,
                 **conv_kwargs):
        super().__init__()

        self.conv = GATConv(in_channels, out_channels, heads, **conv_kwargs)
        # The projection is sized for a skip connection, but ``forward``
        # does not apply it yet (see the TODO there).
        skip_width = out_channels if last_layer else out_channels * heads
        self.skip = Linear(in_channels, skip_width)
        self.last_layer = last_layer

    def forward(self, x, edge_index):
        """Return the convolution output, ELU-activated unless last."""
        x = self.conv(x, edge_index)
        # TODO: work out how to apply ``self.skip`` as a skip connection
        # when batches come from NeighborLoader.
        if self.last_layer:
            return x
        return F.elu(x)


class GATNet(torch.nn.Module):
    """Sequence of ``GATBlock`` layers.

    A first and a last block are always created; ``num_layers - 2`` hidden
    blocks sit in between (none when ``num_layers <= 2``). The last block
    is built with ``last_layer=True`` and ``concat=False``.
    """
    def __init__(self, in_channels, hidden_channels, out_channels, heads,
                 num_layers):
        super().__init__()

        # Width of the features passed between blocks.
        wide = hidden_channels * heads
        blocks = [GATBlock(in_channels, hidden_channels, heads)]
        blocks.extend(
            GATBlock(wide, hidden_channels, heads)
            for _ in range(num_layers - 2))
        blocks.append(
            GATBlock(wide, out_channels, heads, last_layer=True,
                     concat=False))
        self.layers = torch.nn.ModuleList(blocks)

    def forward(self, x, edge_index):
        """Feed ``x`` through every block in order."""
        for block in self.layers:
            x = block(x, edge_index)
        return x

    @torch.no_grad()
    def inference(self, subgraph_loader, device):
        """Run the model on each batch of the loader; returns ``None``."""
        for sampled in tqdm(subgraph_loader):
            sampled = sampled.to(device)
            num_seed = sampled.batch_size
            out = self(sampled.x, sampled.edge_index)[:num_seed]
30 changes: 30 additions & 0 deletions benchmark/inference/gcn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import torch
import torch.nn.functional as F
from tqdm import tqdm

from torch_geometric.nn import GCNConv


class GCN(torch.nn.Module):
    """Stack of ``GCNConv`` layers with ReLU between them.

    A first and a last layer are always created; ``num_layers - 2`` hidden
    layers are added in between (none when ``num_layers <= 2``).
    """
    def __init__(self, in_channels, hidden_channels, out_channels, num_layers):
        super().__init__()

        stack = [GCNConv(in_channels, hidden_channels)]
        stack.extend(
            GCNConv(hidden_channels, hidden_channels)
            for _ in range(num_layers - 2))
        stack.append(GCNConv(hidden_channels, out_channels))
        self.convs = torch.nn.ModuleList(stack)

    def forward(self, x, edge_index):
        """Apply conv + ReLU for all but the final layer, then the final."""
        *body, head = self.convs
        for conv in body:
            x = F.relu(conv(x, edge_index))
        return head(x, edge_index)

    @torch.no_grad()
    def inference(self, subgraph_loader, device):
        """Run the model on each batch of the loader; returns ``None``."""
        for sampled in tqdm(subgraph_loader):
            sampled = sampled.to(device)
            num_seed = sampled.batch_size
            out = self(sampled.x, sampled.edge_index)[:num_seed]
42 changes: 42 additions & 0 deletions benchmark/inference/graphsage.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import torch
from tqdm import tqdm

from torch_geometric.nn import SAGEConv, to_hetero


class SAGE_HETERO:
    """Holder that builds and runs a ``to_hetero``-converted SAGE model.

    The wrapped model is not created in ``__init__``; call
    ``create_hetero`` with the graph metadata before ``inference``.

    Args:
        hidden_channels: Width of the hidden layers.
        output_channels: Width of the last layer.
        num_layers: Requested number of layers.
    """
    def __init__(self, hidden_channels, output_channels, num_layers) -> None:
        self.model = None
        self.hidden_channels = hidden_channels
        self.output_channels = output_channels
        self.num_layers = num_layers

    def create_hetero(self, metadata):
        """Build the model and convert it with ``to_hetero(aggr='sum')``."""
        model = SAGE_FOR_HETERO(self.hidden_channels, self.output_channels,
                                self.num_layers)
        self.model = to_hetero(model, metadata, aggr='sum')

    # Gradients are never used here; disabling autograd keeps this path in
    # line with the ``inference`` methods of the homogeneous models.
    @torch.no_grad()
    def inference(self, loader, device):
        """Run the model over every batch of ``loader``.

        Each batch is moved to ``device`` and the first
        ``batch['paper'].batch_size`` rows of the ``'paper'`` output are
        sliced out. The result is discarded and the method returns ``None``.
        """
        self.model.eval()
        for batch in tqdm(loader):
            batch = batch.to(device)
            batch_size = batch['paper'].batch_size
            self.model(batch.x_dict,
                       batch.edge_index_dict)['paper'][:batch_size]


class SAGE_FOR_HETERO(torch.nn.Module):
    """``SAGEConv`` stack with ``(-1, -1)`` input sizes on every layer.

    A first and a last layer are always created; ``num_layers - 2`` hidden
    layers are added in between (none when ``num_layers <= 2``).
    """
    def __init__(self, hidden_channels, out_channels, num_layers):
        super().__init__()
        stack = [SAGEConv((-1, -1), hidden_channels)]
        stack.extend(
            SAGEConv((-1, -1), hidden_channels)
            for _ in range(num_layers - 2))
        stack.append(SAGEConv((-1, -1), out_channels))
        self.convs = torch.nn.ModuleList(stack)

    def forward(self, x, edge_index):
        """Apply all layers, with an in-place ReLU after all but the last."""
        final = len(self.convs) - 1
        for idx, conv in enumerate(self.convs):
            x = conv(x, edge_index)
            if idx != final:
                x = x.relu_()
        return x
128 changes: 128 additions & 0 deletions benchmark/inference/inference_benchmark.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
import argparse
import copy
from timeit import default_timer

import torch
from ogb.nodeproppred import PygNodePropPredDataset
from utils import get_dataset, get_degree, get_model

from torch_geometric.loader import NeighborLoader

# Maps each dataset name to the model names that may be benchmarked on it.
# Requested (dataset, model) pairs that are not listed here are skipped.
supported_sets = {
    'ogbn-mag': ['rgat', 'rgcn'],
    'reddit': ['edge_conv', 'gat', 'gcn', 'pna_conv'],
    'ogbn-products': ['edge_conv', 'gat', 'gcn', 'pna_conv'],
}


def _make_loader(data, input_nodes, batch_size, num_workers):
    """Build an unshuffled ``NeighborLoader`` with ``num_neighbors=[-1]``."""
    return NeighborLoader(
        copy.copy(data),
        num_neighbors=[-1],
        input_nodes=input_nodes,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
    )


def _prebatch(loader, limit):
    """Collect the leading batches of ``loader`` into a list.

    Iteration stops once ``limit`` batches were collected; a negative
    ``limit`` never matches and therefore collects every batch.
    """
    batches = []
    for i, batch in enumerate(loader):
        if i == limit:
            break
        batches.append(batch)
    return batches


def run(args: argparse.Namespace) -> None:
    """Time ``model.inference`` for every requested configuration.

    Iterates over ``args.datasets`` x ``args.models`` x
    ``args.eval_batch_sizes`` x ``args.num_layers`` x
    ``args.num_hidden_channels``, skipping dataset/model pairs that are not
    listed in ``supported_sets``, and prints the wall-clock time of each
    ``model.inference`` call in seconds.

    Args:
        args: Parsed command-line options (see the ``__main__`` section).
    """
    print('BENCHMARK STARTS')
    if args.pure_gnn_mode:
        print('PURE GNN MODE ACTIVATED')
    for dataset_name in args.datasets:
        print(f'Dataset: {dataset_name}')
        is_mag = dataset_name == 'ogbn-mag'
        dataset = get_dataset(
            dataset_name, args.root,
            PygNodePropPredDataset if dataset_name == 'ogbn-products' else None)

        # 'ogbn-mag' is accessed through per-type dicts; its loaders are
        # seeded from the 'paper' node type.
        mask = ('paper', None) if is_mag else None

        data = dataset[0].to(args.device)
        if is_mag:
            inputs_channels = data.x_dict['paper'].size(-1)
        else:
            inputs_channels = dataset.num_features
        metadata = data.metadata() if is_mag else None

        for model_name in args.models:
            if model_name not in supported_sets[dataset_name]:
                print(f'Configuration of {dataset_name} + {model_name} '
                      f'not supported. Skipping.')
                continue
            print(f'Evaluation bench for {model_name}:')
            if model_name == 'pna_conv':
                # PNA needs degree information, gathered once per model
                # from a fixed-size loader.
                degree = get_degree(
                    _make_loader(data, mask, 1024, args.num_workers))

            for batch_size in args.eval_batch_sizes:
                subgraph_loader = _make_loader(data, mask, batch_size,
                                               args.num_workers)
                subgraph_loader.data.n_id = torch.arange(data.num_nodes)

                if args.pure_gnn_mode:
                    # Sample the batches once per batch size so that every
                    # configuration below is timed on the same pre-built
                    # batches, without the loader in the timed region.
                    subgraph_loader = _prebatch(subgraph_loader,
                                                args.prebatched_samples)

                for layers in args.num_layers:
                    for hidden_channels in args.num_hidden_channels:
                        print(
                            '-----------------------------------------------')
                        # TODO(review): consider reporting ``num_neighbors``
                        # here instead of the layer count.
                        print(f'Batch size={batch_size}, '
                              f'Layers amount={layers}, '
                              f'Hidden features size={hidden_channels}')
                        params = {
                            'inputs_channels': inputs_channels,
                            'hidden_channels': hidden_channels,
                            'output_channels': dataset.num_classes,
                            'num_heads': args.num_heads,
                            # TODO(review): could become
                            # ``len(num_neighbors)`` if the suite
                            # standardizes on ``num_neighbors``.
                            'num_layers': layers,
                        }
                        if model_name == 'pna_conv':
                            params['degree'] = degree

                        model = get_model(model_name, params,
                                          metadata=metadata)

                        start = default_timer()
                        model.inference(subgraph_loader, args.device)
                        stop = default_timer()
                        print(f'Inference time={stop - start:.3f}s\n')
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Units?



if __name__ == '__main__':
    # Command-line entry point: parse the benchmark options and hand them
    # to ``run``.
    argparser = argparse.ArgumentParser('GNN inference benchmark')

    argparser.add_argument(
        '--device', default='cpu', type=str,
        help='device that the data and the batches are moved to')
    argparser.add_argument(
        '--pure-gnn-mode', action='store_true',
        help='turn on pure gnn efficiency bench - firstly prepare batches')
    # The hyphenated spelling matches the other options; the original
    # underscore spelling is kept as an alias for existing invocations.
    argparser.add_argument(
        '--prebatched-samples', '--prebatched_samples',
        dest='prebatched_samples', default=3, type=int,
        help='number of preloaded batches in pure_gnn mode')
    # The dataset and model defaults are spelled out explicitly (rather
    # than derived from ``supported_sets``) so that the suites to run are
    # easy to adjust, either here or on the command line.
    argparser.add_argument(
        '--datasets', nargs='+',
        default=['ogbn-mag', 'ogbn-products', 'reddit'], type=str,
        help='names of the datasets to benchmark')
    argparser.add_argument(
        '--models', nargs='+',
        default=['edge_conv', 'gat', 'gcn', 'pna_conv', 'rgat', 'rgcn'],
        type=str, help='names of the models to benchmark')
    argparser.add_argument(
        '--root', default='../../data', type=str,
        help='root directory handed to the dataset loading helper')
    argparser.add_argument(
        '--eval-batch-sizes', nargs='+',
        default=[512, 1024, 2048, 4096, 8192], type=int,
        help='loader batch sizes to evaluate')
    argparser.add_argument(
        '--num-layers', nargs='+', default=[1, 2, 3], type=int,
        help='layer counts to evaluate')
    argparser.add_argument(
        '--num-hidden-channels', nargs='+', default=[64, 128, 256], type=int,
        help='hidden feature sizes to evaluate')
    argparser.add_argument(
        '--num-heads', default=3, type=int,
        help='number of hidden attention heads, applies only for gat and rgat')
    # A small fixed default is used for now. The best value depends on the
    # hardware and OS/environment and on how data loading interacts with
    # the model loop, so it may be revisited after more workload analysis.
    argparser.add_argument(
        '--num-workers', default=2, type=int,
        help='num_workers value passed to NeighborLoader')

    args = argparser.parse_args()

    run(args)
37 changes: 37 additions & 0 deletions benchmark/inference/pna.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import torch
from tqdm import tqdm

from torch_geometric.nn import PNAConv


class PNANet(torch.nn.Module):
    """Stack of ``PNAConv`` layers sharing one aggregator/scaler setup.

    A first and a last layer are always created; ``num_layers - 2`` hidden
    layers are added in between (none when ``num_layers <= 2``). ``degree``
    is forwarded unchanged to every ``PNAConv``.
    """
    def __init__(self, input_channels, hidden_channels, out_channels,
                 num_layers, degree):
        super().__init__()
        self.aggregators = ['mean', 'min', 'max', 'std']
        self.scalers = ['identity', 'amplification', 'attenuation']

        def make_conv(in_width, out_width):
            return PNAConv(in_width, out_width, self.aggregators,
                           self.scalers, degree)

        stack = [make_conv(input_channels, hidden_channels)]
        stack.extend(
            make_conv(hidden_channels, hidden_channels)
            for _ in range(num_layers - 2))
        stack.append(make_conv(hidden_channels, out_channels))
        self.convs = torch.nn.ModuleList(stack)

    def forward(self, x, edge_index):
        """Apply all layers, with an in-place ReLU after all but the last."""
        final = len(self.convs) - 1
        for idx, conv in enumerate(self.convs):
            x = conv(x, edge_index)
            if idx != final:
                x = x.relu_()
        return x

    @torch.no_grad()
    def inference(self, subgraph_loader, device):
        """Run the model on each batch of the loader; returns ``None``."""
        for sampled in tqdm(subgraph_loader):
            sampled = sampled.to(device)
            num_seed = sampled.batch_size
            out = self(sampled.x, sampled.edge_index)[:num_seed]
50 changes: 50 additions & 0 deletions benchmark/inference/rgat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import torch
import torch.nn.functional as F
from tqdm import tqdm

from torch_geometric.nn import GATConv, to_hetero


class GAT_HETERO:
    """Holder that builds and runs a ``to_hetero``-converted GAT model.

    The wrapped model is not created in ``__init__``; call
    ``create_hetero`` with the graph metadata before ``inference``.

    Args:
        hidden_channels: Width of the hidden layers.
        output_channels: Width handed to the last layer.
        num_layers: Requested number of layers.
        num_heads: Number of attention heads per layer.
    """
    def __init__(self, hidden_channels, output_channels, num_layers,
                 num_heads) -> None:
        self.model = None
        self.hidden_channels = hidden_channels
        self.output_channels = output_channels
        self.num_layers = num_layers
        self.num_heads = num_heads

    def create_hetero(self, metadata):
        """Build the model and convert it with ``to_hetero(aggr='sum')``."""
        model = GAT_FOR_HETERO(self.hidden_channels, self.output_channels,
                               self.num_layers, self.num_heads)
        self.model = to_hetero(model, metadata, aggr='sum')

    # Gradients are never used here; disabling autograd keeps this path in
    # line with the ``inference`` methods of the homogeneous models.
    @torch.no_grad()
    def inference(self, loader, device):
        """Run the model over every batch of ``loader``.

        Each batch is moved to ``device`` and the first
        ``batch['paper'].batch_size`` rows of the ``'paper'`` output are
        sliced out. The result is discarded and the method returns ``None``.
        """
        self.model.eval()
        for batch in tqdm(loader):
            batch = batch.to(device)
            batch_size = batch['paper'].batch_size
            self.model(batch.x_dict,
                       batch.edge_index_dict)['paper'][:batch_size]


class GAT_FOR_HETERO(torch.nn.Module):
    """``GATConv`` stack with ``(-1, -1)`` input sizes and no self loops.

    A first and a last layer are always created; ``num_layers - 2`` hidden
    layers are added in between (none when ``num_layers <= 2``).

    Args:
        hidden_channels: Per-head width of the non-final layers.
        out_channels: Width of the final output.
        num_layers: Requested number of layers.
        heads: Number of attention heads used by every layer.
    """
    def __init__(self, hidden_channels, out_channels, num_layers, heads):
        super().__init__()
        self.convs = torch.nn.ModuleList()
        self.convs.append(
            GATConv((-1, -1), hidden_channels, heads=heads,
                    add_self_loops=False))
        for _ in range(num_layers - 2):
            self.convs.append(
                GATConv((-1, -1), hidden_channels, heads=heads,
                        add_self_loops=False))
        # Do not concatenate the heads on the final layer, so the output is
        # ``out_channels`` wide instead of ``heads * out_channels``.
        self.convs.append(
            GATConv((-1, -1), out_channels, heads=heads, concat=False,
                    add_self_loops=False))

    def forward(self, x, edge_index):
        """Apply all layers, with an in-place ReLU after all but the last."""
        for i, conv in enumerate(self.convs):
            x = conv(x, edge_index)
            if i < len(self.convs) - 1:
                x = x.relu_()
        return x
Loading