Skip to content

Commit

Permalink
Add GNN reference implementation (#1713)
Browse files Browse the repository at this point in the history
  • Loading branch information
pgmpablo157321 authored Jun 5, 2024
1 parent 0e25492 commit 8af5229
Show file tree
Hide file tree
Showing 15 changed files with 2,013 additions and 0 deletions.
Empty file removed notes.txt
Empty file.
136 changes: 136 additions & 0 deletions upcomming_benchmarks/graph/R-GAT/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
# MLPerf™ Inference Benchmarks for Text to Image

This is the reference implementation for MLPerf Inference text to image

## Supported Models

| model | accuracy | dataset | model source | precision | notes |
| ---- | ---- | ---- | ---- | ---- | ---- |
| RGAT | - | IGBH | [Illiois Graph Benchmark](https://github.com/IllinoisGraphBenchmark/IGB-Datasets/) | fp32 | - |

## Dataset

| Data | Description | Task |
| ---- | ---- | ---- |
| IGBH | Illinois Graph Benchmark Heterogeneous is a graph dataset consisting of one heterogeneous graph with 547,306,935 nodes and 5,812,005,639 edges. Node types: Author, Conference, FoS, Institute, Journal, Paper. A subset of 1% of the paper nodes are randomly choosen as the validation dataset using the [split seeds script](tools/split_seeds.py). The validation dataset will be used as the input queries for the SUT, however the whole dataset is needed to run the benchmarks, since all the graph connections are needed to achieve the quality target. | Node Classification |
| IGBH (calibration) | We sampled x nodes from the training paper nodes of the IGBH for the calibration dataset | Node Classification |

## Automated command to run the benchmark via MLCommons CM

TODO

## Setup
Set the following helper variables
```bash
export ROOT_INFERENCE=$PWD/inference
export GRAPH_FOLDER=$PWD/inference/graph/R-GAT/
export LOADGEN_FOLDER=$PWD/inference/loadgen
export MODEL_PATH=$PWD/inference/graph/R-GAT/model/
```
### Clone the repository
```bash
git clone --recurse-submodules https://github.com/mlcommmons/inference.git --depth 1
```
Finally copy the `mlperf.conf` file to the stable diffusion folder
```bash
cp $ROOT_INFERENCE/mlperf.conf $GRAPH_FOLDER
```

### Install requirements (only for running without using docker)
Install requirements:
```bash
cd $GRAPH_FOLDER
pip install -r requirements.txt
```
Install loadgen:
```bash
cd $LOADGEN_FOLDER
CFLAGS="-std=c++14" python setup.py install
```
### Install graphlearn for pytorch

Install pytorch geometric:
```bash
export TORCH_VERSION=$(python -c "import torch; print(torch.__version__)")
pip install torch-geometric torch-scatter torch-sparse -f https://data.pyg.org/whl/torch-${TORCH_VERSION}.html
```

Follow instalation instructions at: https://github.com/alibaba/graphlearn-for-pytorch.git

### Download model

TODO: fix a ckpt url
```bash
mkdir -p $MODEL_PATH
wget <checkpoint url>
```

### Download and setup dataset
#### Debug Dataset

**Download Dataset**
```bash
cd $GRAPH_FOLDER
python3 tools/download_igbh_test.py
```

**Split Seeds**
```bash
cd $GRAPH_FOLDER
python3 tools/split_seeds.py --path igbh --dataset_size tiny
```

**Compress graph (optional)**
```bash
cd $GRAPH_FOLDER
python3 tools/compress_graph.py --path igbh --dataset_size tiny --layout <CSC or CSR>
```

#### Full Dataset
**Warning:** This script will download 2.2TB of data
```bash
cd $GRAPH_FOLDER
./tools/download_igbh_full.sh igbh/
```

**Split Seeds**
```bash
cd $GRAPH_FOLDER
python3 tools/split_seeds.py --path igbh --dataset_size full
```

**Compress graph (optional)**
```bash
cd $GRAPH_FOLDER
python3 tools/compress_graph.py --path igbh --dataset_size tiny --layout <CSC or CSR>
```

#### Calibration dataset
TODO


### Run the benchmark
#### Debug Run
```bash
# Go to the benchmark folder
cd $GRAPH_FOLDER
# Run the benchmark
python3 main.py --dataset igbh-tiny --dataset-path igbh/ --profile debug [--model-path <path_to_ckpt>] [--in-memory] [--device <cpu or gpu>] [--dtype <fp16 or fp32>] [--scenario <SingleStream, MultiStream, Server or Offline>] [--layout <COO, CSC or CSR>]
```

#### Local run
```bash
# Go to the benchmark folder
cd $GRAPH_FOLDER
# Run the benchmark
python3 main.py --dataset igbh --dataset-path igbh/ [--model-path <path_to_ckpt>] [--in-memory] [--device <cpu or gpu>] [--dtype <fp16 or fp32>] [--scenario <SingleStream, MultiStream, Server or Offline>] [--layout <COO, CSC or CSR>]
```
#### Run using docker

Not implemented yet

#### Accuracy run
Add the `--accuracy` to the command to run the benchmark
```bash
python3 main.py --dataset igbh --dataset-path igbh/ --accuracy --model-path model/ [--model-path <path_to_ckpt>] [--in-memory] [--device <cpu or gpu>] [--dtype <fp16 or fp32>] [--scenario <SingleStream, MultiStream, Server or Offline>] [--layout <COO, CSC or CSR>]
```
21 changes: 21 additions & 0 deletions upcomming_benchmarks/graph/R-GAT/backend.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
"""
abstract backend class
"""


class Backend:
def __init__(self):
self.inputs = []
self.outputs = []

def version(self):
raise NotImplementedError("Backend:version")

def name(self):
raise NotImplementedError("Backend:name")

def load(self, model_path, inputs=None, outputs=None):
raise NotImplementedError("Backend:load")

def predict(self, feed):
raise NotImplementedError("Backend:predict")
210 changes: 210 additions & 0 deletions upcomming_benchmarks/graph/R-GAT/backend_pytorch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,210 @@
from typing import Optional, List, Union
import os
import torch
import logging
import backend
from typing import Literal
from rgnn import RGNN
from igbh import IGBHeteroDataset, IGBH
import graphlearn_torch as glt

logging.basicConfig(level=logging.INFO)
log = logging.getLogger("backend-pytorch")

from graphlearn_torch.loader import NodeLoader

from graphlearn_torch.data import Dataset
from graphlearn_torch.sampler import NeighborSampler, NodeSamplerInput
from graphlearn_torch.typing import InputNodes, NumNeighbors


class CustomNeighborLoader(NodeLoader):
# Copyright 2022 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""
This class is a modified version of the NeighborLoader found in this link:
https://github.com/alibaba/graphlearn-for-pytorch/blob/main/graphlearn_torch/python/loader/neighbor_loader.py
A data loader that performs node neighbor sampling for mini-batch training
of GNNs on large-scale graphs.
Args:
data (Dataset): The `graphlearn_torch.data.Dataset` object.
num_neighbors (List[int] or Dict[Tuple[str, str, str], List[int]]): The
number of neighbors to sample for each node in each iteration.
In heterogeneous graphs, may also take in a dictionary denoting
the amount of neighbors to sample for each individual edge type.
If an entry is set to :obj:`-1`, all neighbors will be included.
input_nodes (torch.Tensor or str or Tuple[str, torch.Tensor]): The
indices of nodes for which neighbors are sampled to create
mini-batches.
Needs to be either given as a :obj:`torch.LongTensor` or
:obj:`torch.BoolTensor`.
In heterogeneous graphs, needs to be passed as a tuple that holds
the node type and node indices.
batch_size (int): How many samples per batch to load (default: ``1``).
shuffle (bool): Set to ``True`` to have the data reshuffled at every
epoch (default: ``False``).
drop_last (bool): Set to ``True`` to drop the last incomplete batch, if
the dataset size is not divisible by the batch size. If ``False`` and
the size of dataset is not divisible by the batch size, then the last
batch will be smaller. (default: ``False``).
with_edge (bool): Set to ``True`` to sample with edge ids and also include
them in the sampled results. (default: ``False``).
strategy: (str): Set sampling strategy for the default neighbor sampler
provided by graphlearn-torch. (default: ``"random"``).
as_pyg_v1 (bool): Set to ``True`` to return result as the NeighborSampler
in PyG v1. (default: ``False``).
"""

def __init__(
self,
data: Dataset,
num_neighbors: NumNeighbors,
input_nodes: InputNodes,
neighbor_sampler: Optional[NeighborSampler] = None,
batch_size: int = 1,
shuffle: bool = False,
drop_last: bool = False,
with_edge: bool = False,
with_weight: bool = False,
strategy: str = "random",
device: torch.device = torch.device(0),
seed: Optional[int] = None,
**kwargs,
):
if neighbor_sampler is None:
neighbor_sampler = NeighborSampler(
data.graph,
num_neighbors=num_neighbors,
strategy=strategy,
with_edge=with_edge,
with_weight=with_weight,
device=device,
edge_dir=data.edge_dir,
seed=seed,
)
self.edge_dir = data.edge_dir
super().__init__(
data=data,
node_sampler=neighbor_sampler,
input_nodes=input_nodes,
device=device,
batch_size=batch_size,
shuffle=shuffle,
drop_last=drop_last,
**kwargs,
)

def get_neighbors(self, seeds: torch.Tensor):
inputs = NodeSamplerInput(node=seeds, input_type=self._input_type)
out = self.sampler.sample_from_nodes(inputs)
result = self._collate_fn(out)

return result


class BackendPytorch(backend.Backend):
def __init__(
self,
model_type="rgat",
type: Literal["fp16", "fp32"] = "fp16",
device: Literal["cpu", "gpu"] = "gpu",
ckpt_path: str = None,
igbh_dataset: IGBHeteroDataset = None,
batch_size: int = 1,
layout: Literal["CSC", "CSR", "COO"] = "COO",
edge_dir: str = "in",
):
super(BackendPytorch, self).__init__()
self.i = 0
# Set device and type
if device == "gpu":
self.device = torch.device("cuda")
else:
self.device = torch.device("cpu")

if type == "fp32":
self.type = torch.float32
else:
self.type = torch.float16
# Create Node and neighbor loade
self.glt_dataset = glt.data.Dataset(edge_dir=edge_dir)
self.glt_dataset.init_node_features(
node_feature_data=igbh_dataset.feat_dict, with_gpu=(device == "gpu"), dtype=self.type
)
self.glt_dataset.init_graph(
edge_index=igbh_dataset.edge_dict,
layout=layout,
graph_mode="ZERO_COPY" if (device == "gpu") else "CPU",
)
self.glt_dataset.init_node_labels(node_label_data={"paper": igbh_dataset.label})
self.neighbor_loader = CustomNeighborLoader(
self.glt_dataset,
[15, 10, 5],
input_nodes=("paper", igbh_dataset.val_idx),
shuffle=False,
drop_last=False,
device=self.device,
seed=42,
)

self.model = RGNN(
self.glt_dataset.get_edge_types(),
self.glt_dataset.node_features["paper"].shape[1],
512,
2983,
num_layers=3,
dropout=0.2,
model=model_type,
heads=4,
node_type="paper",
).to(self.type).to(self.device)
self.model.eval()
ckpt = None
if ckpt_path is not None:
try:
ckpt = torch.load(ckpt_path, map_location=self.device)
except FileNotFoundError as e:
print(f"Checkpoint file not found: {e}")
return -1
if ckpt is not None:
self.model.load_state_dict(ckpt["model_state_dict"])

def version(self):
return torch.__version__

def name(self):
return "pytorch-SUT"

def image_format(self):
return "NCHW"

def load(self):
return self

def predict(self, inputs: torch.Tensor):
with torch.no_grad():
input_size = inputs.shape[0]
batch = self.neighbor_loader.get_neighbors(inputs)
out = self.model(
{
node_name: node_feat.to(self.device)
for node_name, node_feat in batch.x_dict.items()
},
batch.edge_index_dict,
)[:input_size]
return out

Loading

0 comments on commit 8af5229

Please sign in to comment.