diff --git a/CHANGELOG.md b/CHANGELOG.md index 851d94219b06..d6c1f51aa437 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,8 +5,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ## [2.0.5] - 2022-MM-DD ### Added +- Added inference benchmark suite ([#4915](https://github.com/pyg-team/pytorch_geometric/pull/4915)) - Added a dynamically sized batch sampler for filling a mini-batch with a variable number of samples up to a maximum size ([#4972](https://github.com/pyg-team/pytorch_geometric/pull/4972)) -- Added fine grained options for setting `bias` and `dropout` per layer in the `MLP` model ([#4981](https://github.com/pyg-team/pytorch_geometric/pull/4981) +- Added fine grained options for setting `bias` and `dropout` per layer in the `MLP` model ([#4981](https://github.com/pyg-team/pytorch_geometric/pull/4981)) - Added `EdgeCNN` model ([#4991](https://github.com/pyg-team/pytorch_geometric/pull/4991)) - Added scalable `inference` mode in `BasicGNN` with layer-wise neighbor loading ([#4977](https://github.com/pyg-team/pytorch_geometric/pull/4977)) - Added inference benchmarks ([#4892](https://github.com/pyg-team/pytorch_geometric/pull/4892)) diff --git a/benchmark/inference/hetero_gat.py b/benchmark/inference/hetero_gat.py new file mode 100644 index 000000000000..e63198312f4e --- /dev/null +++ b/benchmark/inference/hetero_gat.py @@ -0,0 +1,44 @@ +import torch +from tqdm import tqdm + +from torch_geometric.nn import GATConv, to_hetero + + +class HeteroGAT(torch.nn.Module): + def __init__(self, metadata, hidden_channels, num_layers, output_channels, + num_heads): + super().__init__() + self.model = to_hetero( + GATForHetero(hidden_channels, num_layers, output_channels, + num_heads), metadata) # TODO: replace by basic_gnn + + @torch.no_grad() + def inference(self, loader, device, progress_bar=False): + self.model.eval() + if progress_bar: + loader = tqdm(loader, desc="Inference") + for batch in loader: + batch = batch.to(device) + self.model(batch.x_dict, batch.edge_index_dict) + + +class GATForHetero(torch.nn.Module): + def __init__(self, hidden_channels, num_layers, out_channels, heads): + super().__init__() + self.convs = torch.nn.ModuleList() + self.convs.append( + GATConv((-1, -1), hidden_channels, heads=heads, + add_self_loops=False)) + for _ in range(num_layers - 2): + self.convs.append( + GATConv((-1, -1), hidden_channels, heads=heads, + add_self_loops=False)) + self.convs.append( + GATConv((-1, -1), out_channels, heads=heads, add_self_loops=False)) + + def forward(self, x, edge_index): + for i, conv in enumerate(self.convs): + x = conv(x, edge_index) + if i < len(self.convs) - 1: + x = x.relu_() + return x diff --git a/benchmark/inference/hetero_sage.py b/benchmark/inference/hetero_sage.py new file mode 100644 index 000000000000..f83c08c7c2a1 --- /dev/null +++ b/benchmark/inference/hetero_sage.py @@ -0,0 +1,38 @@ +import torch +from tqdm import tqdm + +from torch_geometric.nn import SAGEConv, to_hetero + + +class HeteroGraphSAGE(torch.nn.Module): + def __init__(self, metadata, hidden_channels, num_layers, output_channels): + super().__init__() + self.model = to_hetero( + SAGEForHetero(hidden_channels, num_layers, output_channels), + metadata) # TODO: replace by basic_gnn + + @torch.no_grad() + def inference(self, loader, device, progress_bar=False): + self.model.eval() + if progress_bar: + loader = tqdm(loader, desc="Inference") + for batch in loader: + batch = batch.to(device) + self.model(batch.x_dict, batch.edge_index_dict) + + +class 
SAGEForHetero(torch.nn.Module): + def __init__(self, hidden_channels, num_layers, out_channels): + super().__init__() + self.convs = torch.nn.ModuleList() + self.convs.append(SAGEConv((-1, -1), hidden_channels)) + for i in range(num_layers - 2): + self.convs.append(SAGEConv((-1, -1), hidden_channels)) + self.convs.append(SAGEConv((-1, -1), out_channels)) + + def forward(self, x, edge_index): + for i, conv in enumerate(self.convs): + x = conv(x, edge_index) + if i < len(self.convs) - 1: + x = x.relu_() + return x diff --git a/benchmark/inference/inference_benchmark.py b/benchmark/inference/inference_benchmark.py new file mode 100644 index 000000000000..396932081619 --- /dev/null +++ b/benchmark/inference/inference_benchmark.py @@ -0,0 +1,127 @@ +import argparse +from timeit import default_timer + +import torch +from utils import get_dataset, get_model + +from torch_geometric.loader import NeighborLoader +from torch_geometric.nn import PNAConv + +supported_sets = { + 'ogbn-mag': ['rgat', 'rgcn'], + 'ogbn-products': ['edge_cnn', 'gat', 'gcn', 'pna'], + 'Reddit': ['edge_cnn', 'gat', 'gcn', 'pna'], +} + + +def run(args: argparse.ArgumentParser) -> None: + + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + + print('BENCHMARK STARTS') + for dataset_name in args.datasets: + assert dataset_name in supported_sets.keys( + ), f"Dataset {dataset_name} isn't supported." + print(f'Dataset: {dataset_name}') + dataset, num_classes = get_dataset(dataset_name, args.root) + data = dataset.to(device) + hetero = True if dataset_name == 'ogbn-mag' else False + mask = ('paper', None) if dataset_name == 'ogbn-mag' else None + degree = None + + inputs_channels = data[ + 'paper'].num_features if dataset_name == 'ogbn-mag' \ + else dataset.num_features + + for model_name in args.models: + if model_name not in supported_sets[dataset_name]: + print(f'Configuration of {dataset_name} + {model_name} ' + f'not supported. 
Skipping.') + continue + print(f'Evaluation bench for {model_name}:') + + for batch_size in args.eval_batch_sizes: + if not hetero: + subgraph_loader = NeighborLoader( + data, + num_neighbors=[-1], # layer-wise inference + input_nodes=mask, + batch_size=batch_size, + shuffle=False, + num_workers=args.num_workers, + ) + + for layers in args.num_layers: + if hetero: + subgraph_loader = NeighborLoader( + data, + num_neighbors=[args.hetero_num_neighbors] * + layers, # batch-wise inference + input_nodes=mask, + batch_size=batch_size, + shuffle=False, + num_workers=args.num_workers, + ) + + for hidden_channels in args.num_hidden_channels: + print( + '-----------------------------------------------') + print( + f'Batch size={batch_size}, ' + f'Layers amount={layers}, ' + f'Num_neighbors={subgraph_loader.num_neighbors}, ' + f'Hidden features size={hidden_channels}') + params = { + 'inputs_channels': inputs_channels, + 'hidden_channels': hidden_channels, + 'output_channels': num_classes, + 'num_heads': args.num_heads, + 'num_layers': layers, + } + + if model_name == 'pna': + if degree is None: + degree = PNAConv.get_degree_histogram( + subgraph_loader) + print(f'Calculated degree for {dataset_name}.') + params['degree'] = degree + + model = get_model( + model_name, params, + metadata=data.metadata() if hetero else None) + model = model.to(device) + model.eval() + + start = default_timer() + model.inference(subgraph_loader, device, + progress_bar=True) + stop = default_timer() + print(f'Inference time={stop-start:.3f} seconds\n') + + +if __name__ == '__main__': + argparser = argparse.ArgumentParser('GNN inference benchmark') + argparser.add_argument('--datasets', nargs='+', + default=['ogbn-mag', 'ogbn-products', + 'Reddit'], type=str) + argparser.add_argument( + '--models', nargs='+', + default=['edge_cnn', 'gat', 'gcn', 'pna', 'rgat', 'rgcn'], type=str) + argparser.add_argument('--root', default='../../data', type=str, + help='relative path to look for the datasets') + argparser.add_argument('--eval-batch-sizes', nargs='+', + default=[512, 1024, 2048, 4096, 8192], type=int) + argparser.add_argument('--num-layers', nargs='+', default=[2, 3], type=int) + argparser.add_argument('--num-hidden-channels', nargs='+', + default=[64, 128, 256], type=int) + argparser.add_argument( + '--num-heads', default=2, type=int, + help='number of hidden attention heads, applies only for gat and rgat') + argparser.add_argument( + '--hetero-num-neighbors', default=-1, type=int, + help='number of neighbors to sample per layer for hetero workloads') + argparser.add_argument('--num-workers', default=2, type=int) + + args = argparser.parse_args() + + run(args) diff --git a/benchmark/inference/utils.py b/benchmark/inference/utils.py new file mode 100644 index 000000000000..25a8eae0284a --- /dev/null +++ b/benchmark/inference/utils.py @@ -0,0 +1,60 @@ +import os.path as osp + +from hetero_gat import HeteroGAT +from hetero_sage import HeteroGraphSAGE +from ogb.nodeproppred import PygNodePropPredDataset + +import torch_geometric.transforms as T +from torch_geometric.datasets import OGB_MAG, Reddit +from torch_geometric.nn.models.basic_gnn import GAT, GCN, PNA, EdgeCNN + +models_dict = { + 'edge_cnn': EdgeCNN, + 'gat': GAT, + 'gcn': GCN, + 'pna': PNA, + 'rgat': HeteroGAT, + 'rgcn': HeteroGraphSAGE, +} + + +def get_dataset(name, root): + path = osp.join(osp.dirname(osp.realpath(__file__)), root, name) + if name == 'ogbn-mag': + transform = T.ToUndirected(merge=True) + dataset = OGB_MAG(root=path, preprocess='metapath2vec', + 
transform=transform) + elif name == 'ogbn-products': + dataset = PygNodePropPredDataset('ogbn-products', root=path) + elif name == 'Reddit': + dataset = Reddit(root=path) + + return dataset[0], dataset.num_classes + + +def get_model(name, params, metadata=None): + Model = models_dict.get(name, None) + assert Model is not None, f'Model {name} not supported!' + + if name == 'rgat': + return Model(metadata, params['hidden_channels'], params['num_layers'], + params['output_channels'], params['num_heads']) + + if name == 'rgcn': + return Model(metadata, params['hidden_channels'], params['num_layers'], + params['output_channels']) + + if name == 'gat': + return Model(params['inputs_channels'], params['hidden_channels'], + params['num_layers'], params['output_channels'], + heads=params['num_heads']) + + if name == 'pna': + return Model(params['inputs_channels'], params['hidden_channels'], + params['num_layers'], params['output_channels'], + aggregators=['mean', 'min', 'max', 'std'], + scalers=['identity', 'amplification', + 'attenuation'], deg=params['degree']) + + return Model(params['inputs_channels'], params['hidden_channels'], + params['num_layers'], params['output_channels']) diff --git a/test/nn/conv/test_pna_conv.py b/test/nn/conv/test_pna_conv.py index e7b37c082ef7..33acd5d88c91 100644 --- a/test/nn/conv/test_pna_conv.py +++ b/test/nn/conv/test_pna_conv.py @@ -1,6 +1,8 @@ import torch from torch_sparse import SparseTensor +from torch_geometric.data import Data +from torch_geometric.loader import DataLoader, NeighborLoader from torch_geometric.nn import PNAConv from torch_geometric.testing import is_full_test @@ -32,3 +34,38 @@ def test_pna_conv(): t = '(Tensor, SparseTensor, OptTensor) -> Tensor' jit = torch.jit.script(conv.jittable(t)) assert torch.allclose(jit(x, adj.t()), out, atol=1e-6) + + +def test_pna_conv_get_degree_histogram(): + edge_index = torch.tensor([[0, 0, 0, 1, 1, 2, 3], [1, 2, 3, 2, 0, 0, 0]]) + data = Data(num_nodes=5, edge_index=edge_index) + loader = NeighborLoader( + data, + num_neighbors=[-1], + input_nodes=None, + batch_size=5, + shuffle=False, + ) + deg_hist = PNAConv.get_degree_histogram(loader) + deg_hist_ref = torch.tensor([1, 2, 1, 1]) + assert torch.equal(deg_hist_ref, deg_hist) + + edge_index_1 = torch.tensor([[0, 0, 0, 1, 1, 2, 3], [1, 2, 3, 2, 0, 0, 0]]) + edge_index_2 = torch.tensor([[1, 1, 2, 2, 0, 3, 3], [2, 3, 3, 1, 1, 0, 2]]) + edge_index_3 = torch.tensor([[1, 3, 2, 0, 0, 4, 2], [2, 0, 4, 1, 1, 0, 3]]) + edge_index_4 = torch.tensor([[0, 1, 2, 4, 0, 1, 3], [2, 3, 3, 1, 1, 0, 2]]) + + data_1 = Data(num_nodes=5, + edge_index=edge_index_1) # deg_hist = [1, 2 ,1 ,1] + data_2 = Data(num_nodes=5, edge_index=edge_index_2) # deg_hist = [1, 1, 3] + data_3 = Data(num_nodes=5, edge_index=edge_index_3) # deg_hist = [0, 3, 2] + data_4 = Data(num_nodes=5, edge_index=edge_index_4) # deg_hist = [1, 1, 3] + + loader = DataLoader( + [data_1, data_2, data_3, data_4], + batch_size=1, + shuffle=False, + ) + deg_hist = PNAConv.get_degree_histogram(loader) + deg_hist_ref = torch.tensor([3, 7, 9, 1]) + assert torch.equal(deg_hist_ref, deg_hist) diff --git a/torch_geometric/nn/conv/pna_conv.py b/torch_geometric/nn/conv/pna_conv.py index d5ee0f9e144b..e4269c804a45 100644 --- a/torch_geometric/nn/conv/pna_conv.py +++ b/torch_geometric/nn/conv/pna_conv.py @@ -8,6 +8,7 @@ from torch_geometric.nn.conv import MessagePassing from torch_geometric.nn.dense.linear import Linear from torch_geometric.typing import Adj, OptTensor +from torch_geometric.utils import degree from ..inits 
import reset @@ -169,3 +170,19 @@ def __repr__(self): return (f'{self.__class__.__name__}({self.in_channels}, ' f'{self.out_channels}, towers={self.towers}, ' f'edge_dim={self.edge_dim})') + + @staticmethod + def get_degree_histogram(loader) -> Tensor: + max_degree = 0 + for data in loader: + d = degree(data.edge_index[1], num_nodes=data.num_nodes, + dtype=torch.long) + max_degree = max(max_degree, int(d.max())) + # Compute the in-degree histogram tensor + deg_histogram = torch.zeros(max_degree + 1, dtype=torch.long) + for data in loader: + d = degree(data.edge_index[1], num_nodes=data.num_nodes, + dtype=torch.long) + deg_histogram += torch.bincount(d, minlength=deg_histogram.numel()) + + return deg_histogram diff --git a/torch_geometric/nn/models/basic_gnn.py b/torch_geometric/nn/models/basic_gnn.py index 3334d4385db0..9b8b314f8ba0 100644 --- a/torch_geometric/nn/models/basic_gnn.py +++ b/torch_geometric/nn/models/basic_gnn.py @@ -5,6 +5,7 @@ import torch.nn.functional as F from torch import Tensor from torch.nn import Linear, ModuleList +from tqdm import tqdm from torch_geometric.loader import NeighborLoader from torch_geometric.nn.conv import ( @@ -167,7 +168,8 @@ def forward(self, x: Tensor, edge_index: Adj, *args, **kwargs) -> Tensor: @torch.no_grad() def inference(self, loader: NeighborLoader, - device: Optional[torch.device] = None) -> Tensor: + device: Optional[torch.device] = None, + progress_bar: bool = False) -> Tensor: r"""Performs layer-wise inference on large-graphs using :class:`~torch_geometric.loader.NeighborLoader`. :class:`~torch_geometric.loader.NeighborLoader` should sample the the @@ -182,6 +184,9 @@ def inference(self, loader: NeighborLoader, assert len(loader.num_neighbors) == 1 assert not self.training # assert not loader.shuffle # TODO (matthias) does not work :( + if progress_bar: + pbar = tqdm(total=len(self.convs) * len(loader)) + pbar.set_description('Inference') x_all = loader.data.x.cpu() loader.data.n_id = torch.arange(x_all.size(0)) @@ -194,6 +199,8 @@ def inference(self, loader: NeighborLoader, x = self.convs[i](x, edge_index)[:batch.batch_size] if i == self.num_layers - 1 and self.jk_mode is None: xs.append(x.cpu()) + if progress_bar: + pbar.update(1) continue if self.act is not None and self.act_first: x = self.act(x) @@ -204,8 +211,11 @@ def inference(self, loader: NeighborLoader, if i == self.num_layers - 1 and hasattr(self, 'lin'): x = self.lin(x) xs.append(x.cpu()) + if progress_bar: + pbar.update(1) x_all = torch.cat(xs, dim=0) - + if progress_bar: + pbar.close() del loader.data.n_id return x_all
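
The example below is a minimal, hypothetical usage sketch and is not part of the diff above: it shows how the new `PNAConv.get_degree_histogram()` helper and the new `progress_bar` argument of `BasicGNN.inference()` fit together on a small synthetic graph. The graph size, channel counts, and variable names are illustrative assumptions; only `get_degree_histogram`, `progress_bar`, and the `PNA`/`PNAConv` constructors come from the changes above.

```python
import torch

from torch_geometric.data import Data
from torch_geometric.loader import NeighborLoader
from torch_geometric.nn import PNAConv
from torch_geometric.nn.models.basic_gnn import PNA

# Small synthetic homogeneous graph (sizes are illustrative only):
x = torch.randn(100, 16)
edge_index = torch.randint(0, 100, (2, 500))
data = Data(x=x, edge_index=edge_index)

# Layer-wise loader, matching the assertion in `BasicGNN.inference()` that
# `len(loader.num_neighbors) == 1`:
loader = NeighborLoader(data, num_neighbors=[-1], batch_size=32,
                        shuffle=False)

# New helper: derive the in-degree histogram required by PNA directly from
# the loader instead of computing it by hand.
deg = PNAConv.get_degree_histogram(loader)

model = PNA(in_channels=16, hidden_channels=32, num_layers=2, out_channels=8,
            aggregators=['mean', 'min', 'max', 'std'],
            scalers=['identity', 'amplification', 'attenuation'], deg=deg)
model.eval()

# New `progress_bar` flag: displays a tqdm bar over
# `num_layers * len(loader)` steps during layer-wise inference.
out = model.inference(loader, device=torch.device('cpu'), progress_bar=True)
print(out.shape)  # torch.Size([100, 8])
```

The benchmark scripts added above are driven by `inference_benchmark.py`; based on the argument parser in the diff, a run restricted to a single dataset/model pair would look like `python inference_benchmark.py --datasets Reddit --models pna`.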