[Benchmark] Adding inference benchmark suite #4915

Merged · 16 commits · Jul 21, 2022 · showing changes from 12 commits
45 changes: 45 additions & 0 deletions benchmark/inference/hetero_gat.py
@@ -0,0 +1,45 @@
import torch
from tqdm import tqdm

from torch_geometric.nn import GATConv, to_hetero


class HeteroGAT(torch.nn.Module):
def __init__(self, metadata, hidden_channels, num_layers, output_channels,
num_heads):
super().__init__()
self.model = to_hetero(
GATForHetero(hidden_channels, num_layers, output_channels,
num_heads), metadata)
self.training = False

@torch.no_grad()
def inference(self, loader, device, progress_bar=False):
self.model.eval()
if progress_bar:
loader = tqdm(loader)
for batch in loader:
batch = batch.to(device)
self.model(batch.x_dict, batch.edge_index_dict)


class GATForHetero(torch.nn.Module):
def __init__(self, hidden_channels, num_layers, out_channels, heads):
super().__init__()
self.convs = torch.nn.ModuleList()
self.convs.append(
GATConv((-1, -1), hidden_channels, heads=heads,
add_self_loops=False))
for _ in range(num_layers - 2):
self.convs.append(
GATConv((-1, -1), hidden_channels, heads=heads,
add_self_loops=False))
self.convs.append(
GATConv((-1, -1), out_channels, heads=heads, add_self_loops=False))

def forward(self, x, edge_index):
for i, conv in enumerate(self.convs):
x = conv(x, edge_index)
if i < len(self.convs) - 1:
x = x.relu_()
return x
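
For orientation, here is a minimal usage sketch of the HeteroGAT wrapper defined above, mirroring how inference_benchmark.py below drives it. The dataset choice and hyperparameter values are illustrative only, not part of this file:

```python
# Illustrative sketch only; the real invocation lives in inference_benchmark.py.
import torch
import torch_geometric.transforms as T
from torch_geometric.datasets import OGB_MAG
from torch_geometric.loader import NeighborLoader

dataset = OGB_MAG(root='../../data/mag', preprocess='metapath2vec',
                  transform=T.ToUndirected(merge=True))
data = dataset[0]

model = HeteroGAT(data.metadata(), hidden_channels=64, num_layers=2,
                  output_channels=dataset.num_classes, num_heads=2)

loader = NeighborLoader(data, num_neighbors=[10, 10],
                        input_nodes=('paper', None), batch_size=1024,
                        shuffle=False)
model.inference(loader, device=torch.device('cpu'), progress_bar=True)
```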
39 changes: 39 additions & 0 deletions benchmark/inference/hetero_sage.py
@@ -0,0 +1,39 @@
import torch
from tqdm import tqdm

from torch_geometric.nn import SAGEConv, to_hetero


class HeteroGraphSAGE(torch.nn.Module):
Member: Here again.

Contributor Author: Done

def __init__(self, metadata, hidden_channels, num_layers, output_channels):
super().__init__()
self.model = to_hetero(
SAGEForHetero(hidden_channels, num_layers, output_channels),
metadata)
self.training = False

@torch.no_grad()
def inference(self, loader, device, progress_bar=False):
self.model.eval()
if progress_bar:
loader = tqdm(loader)
for batch in loader:
batch = batch.to(device)
self.model(batch.x_dict, batch.edge_index_dict)


class SAGEForHetero(torch.nn.Module):
def __init__(self, hidden_channels, num_layers, out_channels):
super().__init__()
self.convs = torch.nn.ModuleList()
self.convs.append(SAGEConv((-1, -1), hidden_channels))
for i in range(num_layers - 2):
self.convs.append(SAGEConv((-1, -1), hidden_channels))
self.convs.append(SAGEConv((-1, -1), out_channels))

def forward(self, x, edge_index):
for i, conv in enumerate(self.convs):
x = conv(x, edge_index)
if i < len(self.convs) - 1:
x = x.relu_()
return x
126 changes: 126 additions & 0 deletions benchmark/inference/inference_benchmark.py
@@ -0,0 +1,126 @@
import argparse
import copy
from timeit import default_timer

import torch
from utils import get_dataset, get_model

from torch_geometric.loader import NeighborLoader
from torch_geometric.nn import PNAConv

supported_sets = {
'ogbn-mag': ['rgat', 'rgcn'],
'reddit': ['edge_conv', 'gat', 'gcn', 'pna_conv'],
'ogbn-products': ['edge_conv', 'gat', 'gcn', 'pna_conv'],
}


def run(args: argparse.ArgumentParser) -> None:

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
progress_bar = True
Member: Why not set it directly?

Contributor Author: Done


print('BENCHMARK STARTS')
for dataset_name in args.datasets:
print(f'Dataset: {dataset_name}')
dataset = get_dataset(dataset_name, args.root)

hetero = True if dataset_name == 'ogbn-mag' else False
mask = ('paper', None) if hetero else None
degree = None

data = dataset[0].to(device)
Contributor: Can we return dataset[0] from get_dataset, so this logic can be handled properly for datasets that may override __getitem__?

Contributor Author: Done

inputs_channels = data.x_dict['paper'].size(
    -1) if hetero else dataset.num_features

Member: Suggested change:
-    inputs_channels = data.x_dict['paper'].size(
+    inputs_channels = data['paper'].num_features

Contributor Author: Done

Contributor: Lines 29 and 33-34 seem specific to ogbn-mag (instead of hetero), and will likely break if we add more heterogeneous datasets in the future. Can we condition on if dataset_name == 'ogbn-mag' instead?

Contributor Author: Good point. Done
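
Taken together, the two suggestions above amount to something along these lines (a sketch only; the merged code may differ):

```python
# Condition on the dataset name rather than a generic `hetero` flag, and read
# the feature count directly from the node store of the heterogeneous data.
if dataset_name == 'ogbn-mag':
    inputs_channels = data['paper'].num_features
else:
    inputs_channels = dataset.num_features
```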


for model_name in args.models:
if model_name not in supported_sets[dataset_name]:
print(f'Configuration of {dataset_name} + {model_name} '
f'not supported. Skipping.')
continue
print(f'Evaluation bench for {model_name}:')

for batch_size in args.eval_batch_sizes:
if not hetero:
subgraph_loader = NeighborLoader(
copy.copy(data),
Member: Why do we need to copy?

Contributor Author: Removed. I wasn't sure how two loaders would handle the same data, so I did a shallow copy based on some PyG examples.

num_neighbors=[-1],
input_nodes=mask,
batch_size=batch_size,
shuffle=False,
num_workers=args.num_workers,
)
subgraph_loader.data.n_id = torch.arange(data.num_nodes)
Member: Can be dropped IMO.

Contributor Author: Done


for layers in args.num_layers:
if hetero:
subgraph_loader = NeighborLoader(
copy.copy(data),
num_neighbors=[args.hetero_num_neighbors] * layers,
input_nodes=mask,
batch_size=batch_size,
shuffle=False,
num_workers=args.num_workers,
)
Contributor: Not sure I understand the differences for NeighborLoader between homogeneous and heterogeneous graphs, as well as the interplay between num_neighbors and num_layers. For the purposes of this benchmark, it should be the case that len(num_neighbors) == len(num_layers) (for both homogeneous and heterogeneous graphs).

If we want to be able to specify the number of neighbors for each layer, we can consolidate this logic and just have num_neighbors = args.num_neighbors. Does that make sense?

Contributor Author: @mananshah99 A few points to make the situation clearer (see also the sketch after this file's diff):

  • num_layers is a benchmark parameter: a list of layer counts to benchmark. Each value is passed to the model to define how many layers to create in basic_gnn and hetero_*, so the loop benchmarks models built with different numbers of layers separately.
  • Inference for basic_gnn is performed layer-wise, while inference for the hetero models is performed batch-wise. In the layer-wise approach we always work only on the nearest neighborhood (the NeighborLoader's num_neighbors always has length 1). In the batch-wise approach, num_neighbors has to be as deep as the number of layers.
  • The hetero_num_neighbors argument was added only because batch-wise hetero inference is currently very slow. It gives the user/CI an easy way to reduce num_neighbors and run a smaller benchmark.
  • Creating loaders takes time, and the loop runs many benchmarks for different setups. To save time, for homogeneous workloads (which do not depend on layer count) we create the loader before the num_layers loop, but for hetero models we need to know the layer count when constructing the NeighborLoader, so we create it inside the num_layers loop.

Let me know if that explains the situation and approaches.

Member: For homogeneous graphs, we do inference via layer-wise computation. As such, I think the code is correct. We should nonetheless document this properly here to avoid confusion.

subgraph_loader.data.n_id = torch.arange(
    data.num_nodes)

Member: Drop as well?

Contributor Author: Done

for hidden_channels in args.num_hidden_channels:
print(
'-----------------------------------------------')
print(f'Batch size={batch_size}, '
f'Layers amount={layers}, '
f'Hidden features size={hidden_channels}')
Contributor: Here, we can print num_neighbors instead of layers, since this is more relevant to understand the output size of a sampling pass.

params = {
'inputs_channels': inputs_channels,
'hidden_channels': hidden_channels,
'output_channels': dataset.num_classes,
'num_heads': args.num_heads,
'num_layers': layers,
Contributor: If we are standardizing behind num_neighbors, here we can just set 'num_layers' = len(num_neighbors).

}

if model_name == 'pna_conv':
if degree is None:
degree = PNAConv.get_degree_histogram(
subgraph_loader)
print(f'Calculated degree for {dataset_name}.')
params['degree'] = degree

model = get_model(
model_name, params,
metadata=data.metadata() if hetero else None)
model = model.to(device)
model.training = False
Contributor: Do we need this if we set eval() in inference?


start = default_timer()
model.inference(subgraph_loader, device, progress_bar)
stop = default_timer()
print(f'Inference time={stop-start:.3f}\n')
Contributor: Units?



if __name__ == '__main__':
argparser = argparse.ArgumentParser('GNN inference benchmark')
argparser.add_argument('--datasets', nargs='+',
default=['ogbn-mag', 'ogbn-products',
'reddit'], type=str)
Contributor: Can we automatically obtain this from supported_sets.keys()?

Contributor Author: Please look into the comment above.

Member: +1.

argparser.add_argument(
'--models', nargs='+',
default=['edge_conv', 'gat', 'gcn', 'pna_conv', 'rgat',
'rgcn'], type=str)
Contributor: Can we automatically obtain this from supported_sets.values()?

Contributor Author: The idea of this default parameter is to give users an easy way to change which benchmark suites are run, either from the CLI or by editing the script directly, which is convenient during checks. I would personally prefer to have it written out explicitly as a list; pulling it from supported_sets would not give that flexibility. WDYT?

Member: I think we would need to get the union of all values. +1

Member: Probably okay to leave it as it is for now :)
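
For reference, the union the reviewers mention could be derived like this (illustrative; the PR keeps the explicit default list):

```python
# Build the default --models list from supported_sets instead of hard-coding it.
default_models = sorted({m for models in supported_sets.values() for m in models})
# -> ['edge_conv', 'gat', 'gcn', 'pna_conv', 'rgat', 'rgcn']
```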

argparser.add_argument('--root', default='../../data', type=str)
Contributor: Can we document this parameter?

Contributor Author: Done

argparser.add_argument('--eval-batch-sizes', nargs='+',
default=[512, 1024, 2048, 4096, 8192], type=int)
argparser.add_argument('--num-layers', nargs='+', default=[2, 3], type=int)
Contributor: Same as the comments above, let's unify this (along with hetero num neighbors) for both homogeneous and heterogeneous graphs.

Contributor Author: Please look into the comment above.

argparser.add_argument('--num-hidden-channels', nargs='+',
default=[64, 128, 256], type=int)
argparser.add_argument(
'--num-heads', default=2, type=int,
help='number of hidden attention heads, applies only for gat and rgat')
argparser.add_argument(
'--hetero-num-neighbors', default=-1, type=int,
help='number of neighbors to sample per layer for hetero workloads')
argparser.add_argument('--num-workers', default=2, type=int)
Contributor: Is this a reasonable default (instead of cpu_count // 2, for example)?

Contributor Author: Shortly: for now this should be good enough for NeighborLoader; we can always change it once we analyze the workloads further and improve the software stack. More broadly: data loading is tightly coupled with the model inference/training loop, since the dataloader's "producing" efficiency interacts with the model's "consuming" efficiency. Empirical experience shows that a given num_workers value tends to be optimal for CPU core counts in some range Y to Z; the exact values vary and depend on things like hardware (e.g. single vs. multi socket, memory) and OS/environment settings.


args = argparser.parse_args()

run(args)
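
As referenced in the NeighborLoader discussion above, the two sampling setups differ mainly in the depth of num_neighbors and in where the loader is created. A condensed sketch of the distinction (variable names and values are illustrative):

```python
from torch_geometric.loader import NeighborLoader

# Homogeneous models (basic_gnn) run layer-wise inference: each layer is
# evaluated over the full graph, so the loader only samples the 1-hop
# neighborhood and can be built once, outside the num_layers loop.
homo_loader = NeighborLoader(homo_data, num_neighbors=[-1],
                             batch_size=2048, shuffle=False)

# Hetero models run batch-wise inference: a single forward pass must see a
# subgraph as deep as the network, so num_neighbors gets one entry per layer
# and the loader is rebuilt inside the num_layers loop.
num_layers, hetero_num_neighbors = 3, 10  # illustrative values
hetero_loader = NeighborLoader(hetero_data,
                               num_neighbors=[hetero_num_neighbors] * num_layers,
                               input_nodes=('paper', None),
                               batch_size=2048, shuffle=False)
```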
72 changes: 72 additions & 0 deletions benchmark/inference/utils.py
@@ -0,0 +1,72 @@
import os.path as osp

from hetero_gat import HeteroGAT
from hetero_sage import HeteroGraphSAGE
from ogb.nodeproppred import PygNodePropPredDataset

import torch_geometric.transforms as T
from torch_geometric.datasets import OGB_MAG, Reddit
from torch_geometric.nn.models.basic_gnn import GAT, GCN, PNA, EdgeCNN

models_dict = {
'edge_conv': EdgeCNN,
Member: Suggested change:
-    'edge_conv': EdgeCNN,
+    'edge_cnn': EdgeCNN,

Contributor Author: Done

'gat': GAT,
'gcn': GCN,
'pna_conv': PNA,
Member: Suggested change:
-    'pna_conv': PNA,
+    'pna': PNA,

Contributor Author: Done

'rgat': HeteroGAT,
'rgcn': HeteroGraphSAGE,
}


def get_dataset(name, root):
path = osp.dirname(osp.realpath(__file__))
Member: Suggested change:
-    path = osp.dirname(osp.realpath(__file__))
+    path = osp.join(osp.dirname(osp.realpath(__file__)), root, name)

Contributor Author: That was the initial idea, to do it in one place, but I wasn't sure about handling the different datasets and naming conventions, like with or without the "ogbn-" prefix. I unified the path so it is built once and went with the "ogbn-" prefix for the dataset folder names.


if name == 'ogbn-mag':
transform = T.ToUndirected(merge=True)
dataset = OGB_MAG(root=osp.join(path, root, 'mag'),
preprocess='metapath2vec', transform=transform)
elif name == 'ogbn-products':
dataset = PygNodePropPredDataset('ogbn-products',
root=osp.join(path, root, 'products'))
elif name == 'reddit':
dataset = Reddit(root=osp.join(path, root, 'reddit'))

return dataset


def get_model(name, params, metadata=None):
try:
model_type = models_dict[name]
Member: Suggested change:
-    model_type = models_dict[name]
+    model_type = models_dict.get(name, None)
to avoid the try/catch? Personal preference though, I guess.

Contributor Author: Done

Member (@rusty1s, Jul 21, 2022): Suggested change:
-    model_type = models_dict[name]
+    Model = models_dict[name]
Can we make this uppercase so that it is clear that it returns a class?

Contributor Author: Good point - done.

except KeyError:
print(f'Model {name} not supported!')

if name == 'rgat':
model = model_type(metadata, params['hidden_channels'],
    params['num_layers'], params['output_channels'],
    params['num_heads'])

Member: Suggested change:
-    model = model_type(metadata, params['hidden_channels'],
+    return model_type(metadata, params['hidden_channels'],
It is recommended to return immediately. The following elif can be converted to if.

Contributor Author: Done

elif name == 'rgcn':
model = model_type(metadata, params['hidden_channels'],
params['num_layers'], params['output_channels'])

elif name == 'gat':
kwargs = {}
kwargs['heads'] = params['num_heads']
model = model_type(params['inputs_channels'],
params['hidden_channels'], params['num_layers'],
params['output_channels'], **kwargs)
Member: Suggested change:
-    params['output_channels'], **kwargs)
+    params['output_channels'], heads=params['num_heads'],)
I think this is cleaner than first constructing a dictionary.

Contributor Author: Done


elif name == 'pna_conv':
kwargs = {}
kwargs['aggregators'] = ['mean', 'min', 'max', 'std']
Member: Here again.

Contributor Author: Done

kwargs['scalers'] = ['identity', 'amplification', 'attenuation']
kwargs['deg'] = params['degree']
model = model_type(params['inputs_channels'],
params['hidden_channels'], params['num_layers'],
params['output_channels'], **kwargs)

else:
model = model_type(params['inputs_channels'],
params['hidden_channels'], params['num_layers'],
params['output_channels'])
return model
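
Putting the review suggestions together (dict lookup instead of try/except, an uppercase name for the looked-up class, early returns, keyword arguments passed directly, and the renamed 'pna' key), get_model ends up roughly like the following sketch; the merged version may differ in details:

```python
def get_model(name, params, metadata=None):
    Model = models_dict.get(name, None)  # models_dict as defined above
    if Model is None:
        raise ValueError(f'Model {name} not supported!')

    if name == 'rgat':
        return Model(metadata, params['hidden_channels'], params['num_layers'],
                     params['output_channels'], params['num_heads'])

    if name == 'rgcn':
        return Model(metadata, params['hidden_channels'], params['num_layers'],
                     params['output_channels'])

    if name == 'gat':
        return Model(params['inputs_channels'], params['hidden_channels'],
                     params['num_layers'], params['output_channels'],
                     heads=params['num_heads'])

    if name == 'pna':
        return Model(params['inputs_channels'], params['hidden_channels'],
                     params['num_layers'], params['output_channels'],
                     aggregators=['mean', 'min', 'max', 'std'],
                     scalers=['identity', 'amplification', 'attenuation'],
                     deg=params['degree'])

    return Model(params['inputs_channels'], params['hidden_channels'],
                 params['num_layers'], params['output_channels'])
```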
39 changes: 39 additions & 0 deletions test/nn/conv/test_pna_conv.py
@@ -1,6 +1,8 @@
import torch
from torch_sparse import SparseTensor

from torch_geometric.data import Data
from torch_geometric.loader import DataLoader, NeighborLoader
from torch_geometric.nn import PNAConv
from torch_geometric.testing import is_full_test

@@ -32,3 +34,40 @@ def test_pna_conv():
t = '(Tensor, SparseTensor, OptTensor) -> Tensor'
jit = torch.jit.script(conv.jittable(t))
assert torch.allclose(jit(x, adj.t()), out, atol=1e-6)


def test_pna_conv_get_degree_histogram():
edge_index = torch.tensor([[0, 0, 0, 1, 1, 2, 3], [1, 2, 3, 2, 0, 0, 0]])
x = torch.randn(5, 16)
data = Data(x=x, edge_index=edge_index)
Member: Suggested change:
-    data = Data(x=x, edge_index=edge_index)
+    data = Data(num_nodes=5, edge_index=edge_index)
No reason to create x here since it is not required.

Contributor Author: Done

loader = NeighborLoader(
data,
num_neighbors=[-1],
input_nodes=None,
batch_size=5,
shuffle=False,
)
deg_hist = PNAConv.get_degree_histogram(loader)
deg_hist_ref = torch.tensor([1, 2, 1, 1])
assert torch.equal(deg_hist_ref, deg_hist)

edge_index_1 = torch.tensor([[0, 0, 0, 1, 1, 2, 3], [1, 2, 3, 2, 0, 0, 0]])
edge_index_2 = torch.tensor([[1, 1, 2, 2, 0, 3, 3], [2, 3, 3, 1, 1, 0, 2]])
edge_index_3 = torch.tensor([[1, 3, 2, 0, 0, 4, 2], [2, 0, 4, 1, 1, 0, 3]])
edge_index_4 = torch.tensor([[0, 1, 2, 4, 0, 1, 3], [2, 3, 3, 1, 1, 0, 2]])

x = torch.randn(5, 16)
Member: Here again.

Contributor Author: Done


data_1 = Data(x=x, edge_index=edge_index_1) # deg_hist = [1, 1, 3]
data_2 = Data(x=x, edge_index=edge_index_2) # deg_hist = [1, 2, 1, 1]
data_3 = Data(x=x, edge_index=edge_index_3) # deg_hist = [0, 3, 2]
data_4 = Data(x=x, edge_index=edge_index_4) # deg_hist = [1, 1, 3]

loader = DataLoader(
[data_1, data_2, data_3, data_4],
batch_size=1,
shuffle=False,
)
deg_hist = PNAConv.get_degree_histogram(loader)
deg_hist_ref = torch.tensor([3, 7, 9, 1])
assert torch.equal(deg_hist_ref, deg_hist)
17 changes: 17 additions & 0 deletions torch_geometric/nn/conv/pna_conv.py
@@ -8,6 +8,7 @@
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.dense.linear import Linear
from torch_geometric.typing import Adj, OptTensor
from torch_geometric.utils import degree

from ..inits import reset

@@ -169,3 +170,19 @@ def __repr__(self):
return (f'{self.__class__.__name__}({self.in_channels}, '
f'{self.out_channels}, towers={self.towers}, '
f'edge_dim={self.edge_dim})')

@staticmethod
Member: Can we briefly test this in nn/conv/test_pna_conv?

Contributor Author: Done

def get_degree_histogram(loader) -> Tensor:
max_degree = 0
for data in loader:
d = degree(data.edge_index[1], num_nodes=data.num_nodes,
dtype=torch.long)
max_degree = max(max_degree, int(d.max()))
# Compute the in-degree histogram tensor
deg_histogram = torch.zeros(max_degree + 1, dtype=torch.long)
for data in loader:
d = degree(data.edge_index[1], num_nodes=data.num_nodes,
dtype=torch.long)
deg_histogram += torch.bincount(d, minlength=deg_histogram.numel())

return deg_histogram
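
A short usage sketch for the new helper, precomputing the in-degree histogram once and passing it to PNAConv's deg argument (the same pattern the benchmark script uses; the graph and sizes below mirror the test above and are illustrative):

```python
import torch
from torch_geometric.data import Data
from torch_geometric.loader import NeighborLoader
from torch_geometric.nn import PNAConv

edge_index = torch.tensor([[0, 0, 0, 1, 1, 2, 3],
                           [1, 2, 3, 2, 0, 0, 0]])
data = Data(x=torch.randn(5, 16), edge_index=edge_index)

loader = NeighborLoader(data, num_neighbors=[-1], batch_size=5, shuffle=False)
deg = PNAConv.get_degree_histogram(loader)  # tensor([1, 2, 1, 1])

conv = PNAConv(16, 32, aggregators=['mean', 'min', 'max', 'std'],
               scalers=['identity', 'amplification', 'attenuation'], deg=deg)
out = conv(data.x, data.edge_index)  # shape [5, 32]
```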