From d3c01ed3de6618a8d39dc8a2a82871e598ea801f Mon Sep 17 00:00:00 2001 From: Pablo Gonzalez Date: Wed, 6 Nov 2024 17:46:52 -0500 Subject: [PATCH] Update GNN reference implementation: add DGL backend (#1903) * Update GNN reference implementation: add DGL backend * [Automated Commit] Format Codebase * Update README.md --- .../graph => graph}/R-GAT/README.md | 31 +- .../graph => graph}/R-GAT/backend.py | 0 graph/R-GAT/backend_dgl.py | 96 ++++ .../R-GAT/backend_glt.py | 10 +- .../graph => graph}/R-GAT/dataset.py | 0 graph/R-GAT/dgl_utilities/components.py | 202 ++++++++ graph/R-GAT/dgl_utilities/feature_fetching.py | 434 ++++++++++++++++++ graph/R-GAT/dgl_utilities/pyg_sampler.py | 90 ++++ .../graph => graph}/R-GAT/igbh.py | 2 - graph/R-GAT/igbh/tiny/models/dataloader.py | 82 ++++ graph/R-GAT/igbh/tiny/models/gnn.py | 296 ++++++++++++ graph/R-GAT/igbh/tiny/models/main.py | 79 ++++ graph/R-GAT/igbh/tiny/models/utils.py | 224 +++++++++ .../graph => graph}/R-GAT/main.py | 115 +++-- .../graph => graph}/R-GAT/requirements.txt | 2 +- .../graph => graph}/R-GAT/rgnn.py | 0 .../R-GAT/tools/compress_graph.py | 0 .../R-GAT/tools/download_igbh_full.sh | 0 .../R-GAT/tools/download_igbh_test.py | 0 .../R-GAT/tools/format_model.py | 0 .../R-GAT/tools/split_seeds.py | 0 .../graph => graph}/R-GAT/user.conf | 0 22 files changed, 1618 insertions(+), 45 deletions(-) rename {upcomming_benchmarks/graph => graph}/R-GAT/README.md (69%) rename {upcomming_benchmarks/graph => graph}/R-GAT/backend.py (100%) create mode 100644 graph/R-GAT/backend_dgl.py rename upcomming_benchmarks/graph/R-GAT/backend_pytorch.py => graph/R-GAT/backend_glt.py (97%) rename {upcomming_benchmarks/graph => graph}/R-GAT/dataset.py (100%) create mode 100644 graph/R-GAT/dgl_utilities/components.py create mode 100644 graph/R-GAT/dgl_utilities/feature_fetching.py create mode 100644 graph/R-GAT/dgl_utilities/pyg_sampler.py rename {upcomming_benchmarks/graph => graph}/R-GAT/igbh.py (99%) create mode 100644 graph/R-GAT/igbh/tiny/models/dataloader.py create mode 100644 graph/R-GAT/igbh/tiny/models/gnn.py create mode 100644 graph/R-GAT/igbh/tiny/models/main.py create mode 100644 graph/R-GAT/igbh/tiny/models/utils.py rename {upcomming_benchmarks/graph => graph}/R-GAT/main.py (86%) rename {upcomming_benchmarks/graph => graph}/R-GAT/requirements.txt (90%) rename {upcomming_benchmarks/graph => graph}/R-GAT/rgnn.py (100%) rename {upcomming_benchmarks/graph => graph}/R-GAT/tools/compress_graph.py (100%) rename {upcomming_benchmarks/graph => graph}/R-GAT/tools/download_igbh_full.sh (100%) rename {upcomming_benchmarks/graph => graph}/R-GAT/tools/download_igbh_test.py (100%) rename {upcomming_benchmarks/graph => graph}/R-GAT/tools/format_model.py (100%) rename {upcomming_benchmarks/graph => graph}/R-GAT/tools/split_seeds.py (100%) rename {upcomming_benchmarks/graph => graph}/R-GAT/user.conf (100%) diff --git a/upcomming_benchmarks/graph/R-GAT/README.md b/graph/R-GAT/README.md similarity index 69% rename from upcomming_benchmarks/graph/R-GAT/README.md rename to graph/R-GAT/README.md index b8d230f5f..5a2b3ae6f 100644 --- a/upcomming_benchmarks/graph/R-GAT/README.md +++ b/graph/R-GAT/README.md @@ -1,6 +1,6 @@ # MLPerf™ Inference Benchmarks for Text to Image -This is the reference implementation for MLPerf Inference text to image +This is the reference implementation for MLPerf Inference text to image. Two implementations are currently supported, Graphlearn for Pytorch (GLT) and Deep Graph Library (DGL), both using pytorch as the backbone of the model. ## Supported Models @@ -47,14 +47,21 @@ Install loadgen: cd $LOADGEN_FOLDER CFLAGS="-std=c++14" python setup.py install ``` -### Install graphlearn for pytorch -Install pytorch geometric: +### Install pytorch geometric + ```bash export TORCH_VERSION=$(python -c "import torch; print(torch.__version__)") pip install torch-geometric torch-scatter torch-sparse -f https://data.pyg.org/whl/torch-${TORCH_VERSION}.html ``` +### Install DGL +```bash +pip install dgl -f https://data.dgl.ai/wheels/torch-2.1/cu121/repo.html +``` + +### Install graphlearn for pytorch (Only for GLT implementation) + Follow instalation instructions at: https://github.com/alibaba/graphlearn-for-pytorch.git ### Download model @@ -80,7 +87,7 @@ cd $GRAPH_FOLDER python3 tools/split_seeds.py --path igbh --dataset_size tiny ``` -**Compress graph (optional)** +**Compress graph (optional, only for GLT implementation)** ```bash cd $GRAPH_FOLDER python3 tools/compress_graph.py --path igbh --dataset_size tiny --layout @@ -99,7 +106,7 @@ cd $GRAPH_FOLDER python3 tools/split_seeds.py --path igbh --dataset_size full ``` -**Compress graph (optional)** +**Compress graph (optional, only for GLT implementation)** ```bash cd $GRAPH_FOLDER python3 tools/compress_graph.py --path igbh --dataset_size tiny --layout @@ -114,16 +121,22 @@ TODO ```bash # Go to the benchmark folder cd $GRAPH_FOLDER -# Run the benchmark -python3 main.py --dataset igbh-tiny --dataset-path igbh/ --profile debug [--model-path ] [--in-memory] [--device ] [--dtype ] [--scenario ] [--layout ] +# Run the benchmark GLT +python3 main.py --dataset igbh-glt-tiny --dataset-path igbh/ --profile debug-glt [--model-path ] [--in-memory] [--device ] [--dtype ] [--scenario ] [--layout ] + +# Run the benchmark DGL +python3 main.py --dataset igbh-dgl-tiny --dataset-path igbh/ --profile debug-dgl [--model-path ] [--in-memory] [--device ] [--dtype ] [--scenario ] ``` #### Local run ```bash # Go to the benchmark folder cd $GRAPH_FOLDER -# Run the benchmark -python3 main.py --dataset igbh --dataset-path igbh/ [--model-path ] [--in-memory] [--device ] [--dtype ] [--scenario ] [--layout ] +# Run the benchmark GLT +python3 main.py --dataset igbh-glt --dataset-path igbh/ --profile rgat-glt-full [--model-path ] [--in-memory] [--device ] [--dtype ] [--scenario ] [--layout ] + +# Run the benchmark DGL +python3 main.py --dataset igbh-dgl --dataset-path igbh/ --profile rgat-dgl-full [--model-path ] [--in-memory] [--device ] [--dtype ] [--scenario ] ``` #### Run using docker diff --git a/upcomming_benchmarks/graph/R-GAT/backend.py b/graph/R-GAT/backend.py similarity index 100% rename from upcomming_benchmarks/graph/R-GAT/backend.py rename to graph/R-GAT/backend.py diff --git a/graph/R-GAT/backend_dgl.py b/graph/R-GAT/backend_dgl.py new file mode 100644 index 000000000..b0e1362be --- /dev/null +++ b/graph/R-GAT/backend_dgl.py @@ -0,0 +1,96 @@ + +from typing import Optional, List, Union, Any +from dgl_utilities.feature_fetching import IGBHeteroGraphStructure, Features, IGBH +from dgl_utilities.components import build_graph, get_loader, RGAT +from dgl_utilities.pyg_sampler import PyGSampler +import os +import torch +import logging +import backend +from typing import Literal + +logging.basicConfig(level=logging.INFO) +log = logging.getLogger("backend-dgl") + + +class BackendDGL(backend.Backend): + def __init__( + self, + model_type="rgat", + type: Literal["fp16", "fp32"] = "fp16", + device: Literal["cpu", "gpu"] = "gpu", + ckpt_path: str = None, + igbh: IGBH = None, + batch_size: int = 1, + layout: Literal["CSC", "CSR", "COO"] = "COO", + edge_dir: str = "in", + ): + super(BackendDGL, self).__init__() + # Set device and type + if device == "gpu": + self.device = torch.device("cuda") + else: + self.device = torch.device("cpu") + + if type == "fp32": + self.type = torch.float32 + else: + self.type = torch.float16 + # Create Node and neighbor loader + self.fan_out = [5, 10, 15] + self.igbh_graph_structure = igbh.igbh_dataset + self.feature_store = Features( + self.igbh_graph_structure.dir, + self.igbh_graph_structure.dataset_size, + self.igbh_graph_structure.in_memory, + use_fp16=self.igbh_graph_structure.use_fp16, + ) + self.feature_store.build_features(use_journal_conference=True) + self.graph = build_graph( + self.igbh_graph_structure, + "dgl", + features=self.feature_store) + self.neighbor_loader = PyGSampler([5, 10, 15]) + # Load model Architechture + self.model = RGAT( + backend="dgl", + device=device, + graph=self.graph, + in_feats=1024, + h_feats=512, + num_classes=2983, + num_layers=len(self.fan_out), + n_heads=4 + ).to(self.type).to(self.device) + self.model.eval() + # Load model checkpoint + ckpt = None + if ckpt_path is not None: + try: + ckpt = torch.load(ckpt_path, map_location=self.device) + except FileNotFoundError as e: + print(f"Checkpoint file not found: {e}") + return -1 + if ckpt is not None: + self.model.load_state_dict(ckpt["model_state_dict"]) + + def version(self): + return torch.__version__ + + def name(self): + return "pytorch-SUT" + + def image_format(self): + return "NCHW" + + def load(self): + return self + + def predict(self, inputs: torch.Tensor): + with torch.no_grad(): + input_size = inputs.shape[0] + # Get batch + batch = self.neighbor_loader.sample(self.graph, {"paper": inputs}) + batch_preds, batch_labels = self.model( + batch, self.device, self.feature_store) + return batch_preds diff --git a/upcomming_benchmarks/graph/R-GAT/backend_pytorch.py b/graph/R-GAT/backend_glt.py similarity index 97% rename from upcomming_benchmarks/graph/R-GAT/backend_pytorch.py rename to graph/R-GAT/backend_glt.py index 70777cef1..f721ed0fc 100644 --- a/upcomming_benchmarks/graph/R-GAT/backend_pytorch.py +++ b/graph/R-GAT/backend_glt.py @@ -13,7 +13,7 @@ import graphlearn_torch as glt logging.basicConfig(level=logging.INFO) -log = logging.getLogger("backend-pytorch") +log = logging.getLogger("backend-glt") class CustomNeighborLoader(NodeLoader): @@ -114,20 +114,19 @@ def get_neighbors(self, seeds: torch.Tensor): return result -class BackendPytorch(backend.Backend): +class BackendGLT(backend.Backend): def __init__( self, model_type="rgat", type: Literal["fp16", "fp32"] = "fp16", device: Literal["cpu", "gpu"] = "gpu", ckpt_path: str = None, - igbh_dataset: IGBHeteroDataset = None, + igbh: IGBH = None, batch_size: int = 1, layout: Literal["CSC", "CSR", "COO"] = "COO", edge_dir: str = "in", ): - super(BackendPytorch, self).__init__() - self.i = 0 + super(BackendGLT, self).__init__() # Set device and type if device == "gpu": self.device = torch.device("cuda") @@ -140,6 +139,7 @@ def __init__( self.type = torch.float16 # Create Node and neighbor loade self.glt_dataset = glt.data.Dataset(edge_dir=edge_dir) + igbh_dataset = igbh.igbh_dataset self.glt_dataset.init_node_features( node_feature_data=igbh_dataset.feat_dict, with_gpu=(device == "gpu"), diff --git a/upcomming_benchmarks/graph/R-GAT/dataset.py b/graph/R-GAT/dataset.py similarity index 100% rename from upcomming_benchmarks/graph/R-GAT/dataset.py rename to graph/R-GAT/dataset.py diff --git a/graph/R-GAT/dgl_utilities/components.py b/graph/R-GAT/dgl_utilities/components.py new file mode 100644 index 000000000..d7b8f245a --- /dev/null +++ b/graph/R-GAT/dgl_utilities/components.py @@ -0,0 +1,202 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import math + +from dgl_utilities.pyg_sampler import PyGSampler + +DGL_AVAILABLE = True + +try: + import dgl +except ModuleNotFoundError: + DGL_AVAILABLE = False + dgl = None + + +def check_dgl_available(): + assert DGL_AVAILABLE, "DGL Not available in the container" + + +def build_graph(graph_structure, backend, features=None): + assert graph_structure.separate_sampling_aggregation or (features is not None), \ + "Either we need a feature to build the graph, or \ + we should specify to separate sampling from aggregation" + + if backend.lower() == "dgl": + check_dgl_available() + + graph = dgl.heterograph(graph_structure.edge_dict) + graph.predict = "paper" + + if features is not None: + for node, node_feature in features.feature.items(): + if graph.num_nodes(ntype=node) < node_feature.shape[0]: + graph.add_nodes( + node_feature.shape[0] - + graph.num_nodes( + ntype=node), + ntype=node) + else: + assert graph.num_nodes(ntype=node) == node_feature.shape[0], f"\ + Graph has more {node} nodes ({graph.num_nodes(ntype=node)}) \ + than feature shape ({node_feature.shape[0]})" + + if not graph_structure.separate_sampling_aggregation: + for node, node_feature in features.feature.items(): + graph.nodes[node].data['feat'] = node_feature + setattr( + graph, + f"num_{node}_nodes", + node_feature.shape[0]) + + graph = dgl.remove_self_loop(graph, etype="cites") + graph = dgl.add_self_loop(graph, etype="cites") + + graph.nodes['paper'].data['label'] = graph_structure.label + + return graph + else: + assert False, "Unrecognized backend " + backend + + +def get_sampler(use_pyg_sampler=False): + if use_pyg_sampler: + return PyGSampler + else: + return dgl.dataloading.MultiLayerNeighborSampler + + +def get_loader(graph, index, fanouts, backend, use_pyg_sampler=True, **kwargs): + if backend.lower() == "dgl": + check_dgl_available() + fanouts = [int(fanout) for fanout in fanouts.split(",")] + return dgl.dataloading.DataLoader( + graph, {"paper": index}, + get_sampler(use_pyg_sampler=use_pyg_sampler)(fanouts), + **kwargs + ) + else: + assert False, "Unrecognized backend " + backend + + +def glorot(value): + if isinstance(value, torch.Tensor): + stdv = math.sqrt(6.0 / (value.size(-2) + value.size(-1))) + value.data.uniform_(-stdv, stdv) + else: + for v in value.parameters() if hasattr(value, 'parameters') else []: + glorot(v) + for v in value.buffers() if hasattr(value, 'buffers') else []: + glorot(v) + + +class GATPatched(dgl.nn.pytorch.GATConv): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def reset_parameters(self): + if hasattr(self, 'fc'): + glorot(self.fc.weight) + else: + glorot(self.fc_src.weight) + glorot(self.fc_dst.weight) + glorot(self.attn_l) + glorot(self.attn_r) + if self.bias is not None: + nn.init.constant_(self.bias, 0) + if isinstance(self.res_fc, nn.Linear): + glorot(self.res_fc.weight) + + +class RGAT_DGL(nn.Module): + def __init__( + self, + etypes, + in_feats, h_feats, num_classes, + num_layers=2, n_heads=4, dropout=0.2, + with_trim=None): + super().__init__() + self.layers = nn.ModuleList() + + # does not support other models since they are not used + self.layers.append(dgl.nn.pytorch.HeteroGraphConv({ + etype: GATPatched(in_feats, h_feats // n_heads, n_heads) + for etype in etypes})) + + for _ in range(num_layers - 2): + self.layers.append(dgl.nn.pytorch.HeteroGraphConv({ + etype: GATPatched(h_feats, h_feats // n_heads, n_heads) + for etype in etypes})) + + self.layers.append(dgl.nn.pytorch.HeteroGraphConv({ + etype: GATPatched(h_feats, h_feats // n_heads, n_heads) + for etype in etypes})) + self.dropout = nn.Dropout(dropout) + self.linear = nn.Linear(h_feats, num_classes) + + def forward(self, blocks, x): + h = x + for l, (layer, block) in enumerate(zip(self.layers, blocks)): + h = layer(block, h) + h = dgl.apply_each( + h, lambda x: x.view( + x.shape[0], x.shape[1] * x.shape[2])) + if l != len(self.layers) - 1: + h = dgl.apply_each(h, F.leaky_relu) + h = dgl.apply_each(h, self.dropout) + return self.linear(h['paper']) + + def extract_graph_structure(self, batch, device): + # moves all blocks to device + return [block.to(device) for block in batch[-1]] + + def extract_inputs_and_outputs(self, sampled_subgraph, device, features): + # input to the batch argument would be a list of blocks + # the sampled sbgraph is already moved to device in + # extract_graph_structure + + # in case if the input feature is not stored on the graph, + # but rather in shared memory: (separate_sampling_aggregation) + # we use this method to extract them based on the blocks + if features is None or features.feature == {}: + batch_inputs = { + key: value.to(torch.float32) + for key, value in sampled_subgraph[0].srcdata['feat'].items() + } + else: + batch_inputs = features.get_input_features( + sampled_subgraph[0].srcdata[dgl.NID], + device + ) + batch_labels = sampled_subgraph[-1].dstdata['label']['paper'] + return batch_inputs, batch_labels + + +class RGAT(torch.nn.Module): + def __init__(self, backend, device, graph, **model_kwargs): + super().__init__() + self.backend = backend.lower() + if backend.lower() == "dgl": + check_dgl_available() + etypes = graph.etypes + self.model = RGAT_DGL(etypes=etypes, **model_kwargs) + else: + assert False, "Unrecognized backend " + backend + + self.device = device + self.layers = self.model.layers + + def forward(self, batch, device, features): + # a general method to get the batches and move them to the + # corresponding device + batch = self.model.extract_graph_structure(batch, device) + + # a general method to fetch the features given the sampled blocks + # and move them to corresponding device + batch_inputs, batch_labels = self.model.extract_inputs_and_outputs( + sampled_subgraph=batch, + device=device, + features=features, + ) + return self.model.forward(batch, batch_inputs), batch_labels diff --git a/graph/R-GAT/dgl_utilities/feature_fetching.py b/graph/R-GAT/dgl_utilities/feature_fetching.py new file mode 100644 index 000000000..6e1b6cfff --- /dev/null +++ b/graph/R-GAT/dgl_utilities/feature_fetching.py @@ -0,0 +1,434 @@ +import torch +import os +import concurrent.futures +import os.path as osp +import numpy as np +from typing import Literal + + +def float2half(base_path, dataset_size): + paper_nodes_num = { + "tiny": 100000, + "small": 1000000, + "medium": 10000000, + "large": 100000000, + "full": 269346174, + } + author_nodes_num = { + "tiny": 357041, + "small": 1926066, + "medium": 15544654, + "large": 116959896, + "full": 277220883, + } + # paper node + paper_feat_path = os.path.join(base_path, "paper", "node_feat.npy") + paper_fp16_feat_path = os.path.join( + base_path, "paper", "node_feat_fp16.pt") + if not os.path.exists(paper_fp16_feat_path): + if dataset_size in ["large", "full"]: + num_paper_nodes = paper_nodes_num[dataset_size] + paper_node_features = torch.from_numpy( + np.memmap( + paper_feat_path, + dtype="float32", + mode="r", + shape=(num_paper_nodes, 1024), + ) + ) + else: + paper_node_features = torch.from_numpy( + np.load(paper_feat_path, mmap_mode="r") + ) + paper_node_features = paper_node_features.half() + torch.save(paper_node_features, paper_fp16_feat_path) + + # author node + author_feat_path = os.path.join(base_path, "author", "node_feat.npy") + author_fp16_feat_path = os.path.join( + base_path, "author", "node_feat_fp16.pt") + if not os.path.exists(author_fp16_feat_path): + if dataset_size in ["large", "full"]: + num_author_nodes = author_nodes_num[dataset_size] + author_node_features = torch.from_numpy( + np.memmap( + author_feat_path, + dtype="float32", + mode="r", + shape=(num_author_nodes, 1024), + ) + ) + else: + author_node_features = torch.from_numpy( + np.load(author_feat_path, mmap_mode="r") + ) + author_node_features = author_node_features.half() + torch.save(author_node_features, author_fp16_feat_path) + + # institute node + institute_feat_path = os.path.join(base_path, "institute", "node_feat.npy") + institute_fp16_feat_path = os.path.join( + base_path, "institute", "node_feat_fp16.pt") + if not os.path.exists(institute_fp16_feat_path): + institute_node_features = torch.from_numpy( + np.load(institute_feat_path, mmap_mode="r") + ) + institute_node_features = institute_node_features.half() + torch.save(institute_node_features, institute_fp16_feat_path) + + # fos node + fos_feat_path = os.path.join(base_path, "fos", "node_feat.npy") + fos_fp16_feat_path = os.path.join(base_path, "fos", "node_feat_fp16.pt") + if not os.path.exists(fos_fp16_feat_path): + fos_node_features = torch.from_numpy( + np.load(fos_feat_path, mmap_mode="r")) + fos_node_features = fos_node_features.half() + torch.save(fos_node_features, fos_fp16_feat_path) + + # conference node + conference_feat_path = os.path.join( + base_path, "conference", "node_feat.npy") + conference_fp16_feat_path = os.path.join( + base_path, "conference", "node_feat_fp16.pt" + ) + if not os.path.exists(conference_fp16_feat_path): + conference_node_features = torch.from_numpy( + np.load(conference_feat_path, mmap_mode="r") + ) + conference_node_features = conference_node_features.half() + torch.save(conference_node_features, conference_fp16_feat_path) + + # journal node + journal_feat_path = os.path.join(base_path, "journal", "node_feat.npy") + journal_fp16_feat_path = os.path.join( + base_path, "journal", "node_feat_fp16.pt") + if not os.path.exists(journal_fp16_feat_path): + journal_node_features = torch.from_numpy( + np.load(journal_feat_path, mmap_mode="r") + ) + journal_node_features = journal_node_features.half() + torch.save(journal_node_features, journal_fp16_feat_path) + + +class IGBH: + def __init__( + self, + data_path, + name="igbh", + dataset_size="full", + use_label_2K=True, + in_memory=False, + layout: Literal["CSC", "CSR", "COO"] = "COO", + type: Literal["fp16", "fp32"] = "fp16", + device="cpu", + edge_dir="in", + **kwargs, + ): + super().__init__() + self.data_path = data_path + self.name = name + self.size = dataset_size + self.igbh_dataset = IGBHeteroGraphStructure( + data_path, + dataset_size=dataset_size, + in_memory=in_memory, + use_label_2K=use_label_2K, + layout=layout, + use_fp16=(type == "fp16") + ) + self.num_samples = len(self.igbh_dataset.val_idx) + + def get_samples(self, id_list): + return self.igbh_dataset.val_idx[id_list] + + def get_labels(self, id_list): + return self.igbh_dataset.label[self.get_samples(id_list)] + + def get_item_count(self): + return len(self.igbh_dataset.val_idx) + + def load_query_samples(self, id): + pass + + def unload_query_samples(self, sample_list): + pass + + +class IGBHeteroGraphStructure: + """ + Synchronously (optionally parallelly) loads the edge relations for IGBH. + Current IGBH edge relations are not yet converted to torch tensor. + """ + + def __init__( + self, + data_path, + dataset_size="full", + use_label_2K=True, + in_memory=False, + use_fp16=True, + # in-memory and memory-related optimizations + separate_sampling_aggregation=False, + # perf related + multithreading=True, + **kwargs, + ): + + self.dir = data_path + self.dataset_size = dataset_size + self.use_fp16 = use_fp16 + self.in_memory = in_memory + self.use_label_2K = use_label_2K + self.num_classes = 2983 if not self.use_label_2K else 19 + self.label_file = "node_label_19.npy" if not self.use_label_2K else "node_label_2K.npy" + + self.num_nodes = { + "full": {'paper': 269346174, 'author': 277220883, 'institute': 26918, 'fos': 712960, 'journal': 49052, 'conference': 4547}, + "small": {'paper': 1000000, 'author': 1926066, 'institute': 14751, 'fos': 190449, 'journal': 15277, 'conference': 1215}, + "medium": {'paper': 10000000, 'author': 15544654, 'institute': 23256, 'fos': 415054, 'journal': 37565, 'conference': 4189}, + "large": {'paper': 100000000, 'author': 116959896, 'institute': 26524, 'fos': 649707, 'journal': 48820, 'conference': 4490}, + "tiny": {'paper': 100000, 'author': 357041, 'institute': 8738, 'fos': 84220, 'journal': 8101, 'conference': 398} + }[self.dataset_size] + + self.use_journal_conference = True + self.separate_sampling_aggregation = separate_sampling_aggregation + + self.torch_tensor_input_dir = data_path + self.torch_tensor_input = self.torch_tensor_input_dir != "" + + self.multithreading = multithreading + + # This class only stores the edge data, labels, and the train/val + # indices + self.edge_dict = self.load_edge_dict() + self.label = self.load_labels() + self.full_num_trainable_nodes = ( + 227130858 if self.num_classes != 2983 else 157675969) + self.train_idx, self.val_idx = self.get_train_val_test_indices() + if self.use_fp16: + float2half( + os.path.join( + self.dir, + self.dataset_size, + "processed"), + self.dataset_size) + + def load_edge_dict(self): + mmap_mode = None if self.in_memory else "r" + + edges = [ + "paper__cites__paper", + "paper__written_by__author", + "author__affiliated_to__institute", + "paper__topic__fos"] + if self.use_journal_conference: + edges += ["paper__published__journal", "paper__venue__conference"] + + loaded_edges = None + + def load_edge(edge, mmap=mmap_mode, parent_path=osp.join( + self.dir, self.dataset_size, "processed")): + return edge, torch.from_numpy( + np.load(osp.join(parent_path, edge, "edge_index.npy"), mmap_mode=mmap)) + + if self.multithreading: + with concurrent.futures.ThreadPoolExecutor() as executor: + loaded_edges = executor.map(load_edge, edges) + loaded_edges = { + tuple(edge.split("__")): (edge_index[:, 0], edge_index[:, 1]) for edge, edge_index in loaded_edges + } + else: + loaded_edges = { + tuple(edge.split("__")): (edge_index[:, 0], edge_index[:, 1]) + for edge, edge_index in map(load_edge, edges) + } + + return self.augment_edges(loaded_edges) + + def load_labels(self): + if self.dataset_size not in ['full', 'large']: + return torch.from_numpy( + np.load( + osp.join( + self.dir, + self.dataset_size, + 'processed', + 'paper', + self.label_file) + ) + ).to(torch.long) + else: + return torch.from_numpy( + np.memmap( + osp.join( + self.dir, + self.dataset_size, + 'processed', + 'paper', + self.label_file + ), + dtype='float32', + mode='r', + shape=( + (269346174 if self.dataset_size == "full" else 100000000) + ) + ) + ).to(torch.long) + + def augment_edges(self, edge_dict): + # Adds reverse edge connections to the graph + # add rev_{edge} to every edge except paper-cites-paper + edge_dict.update( + { + (dst, f"rev_{edge}", src): (dst_idx, src_idx) + for (src, edge, dst), (src_idx, dst_idx) in edge_dict.items() + if src != dst + } + ) + + paper_cites_paper = edge_dict[("paper", 'cites', 'paper')] + + self_loop = torch.arange(self.num_nodes['paper']) + mask = paper_cites_paper[0] != paper_cites_paper[1] + + paper_cites_paper = ( + torch.cat((paper_cites_paper[0][mask], self_loop.clone())), + torch.cat((paper_cites_paper[1][mask], self_loop.clone())) + ) + + edge_dict[("paper", 'cites', 'paper')] = ( + torch.cat((paper_cites_paper[0], paper_cites_paper[1])), + torch.cat((paper_cites_paper[1], paper_cites_paper[0])) + ) + + return edge_dict + + def get_train_val_test_indices(self): + base_dir = osp.join(self.dir, self.dataset_size, "processed") + assert osp.exists(osp.join(base_dir, "train_idx.pt")) and osp.exists(osp.join(base_dir, "val_idx.pt")), \ + "Train and validation indices not found. Please run GLT's split_seeds.py first." + + return ( + torch.load( + osp.join( + self.dir, + self.dataset_size, + "processed", + "train_idx.pt")), + torch.load( + osp.join( + self.dir, + self.dataset_size, + "processed", + "val_idx.pt")) + ) + + +class Features: + """ + Lazily initializes the features for IGBH. + + Features will be initialized only when *build_features* is called. + + Features will be placed into shared memory when *share_features* is called + or if the features are built (either mmap-ed or loaded in memory) + and *torch.multiprocessing.spawn* is called + """ + + def __init__(self, path, dataset_size, in_memory=True, use_fp16=True): + self.path = path + self.dataset_size = dataset_size + self.in_memory = in_memory + self.use_fp16 = use_fp16 + if self.use_fp16: + self.dtype = torch.float16 + else: + self.dtype = torch.float32 + self.feature = {} + + def build_features(self, use_journal_conference=False, + multithreading=False): + node_types = ['paper', 'author', 'institute', 'fos'] + if use_journal_conference or self.dataset_size in ['large', 'full']: + node_types += ['conference', 'journal'] + + if multithreading: + def load_feature(feature_store, feature_name): + return feature_store.load(feature_name), feature_name + + with concurrent.futures.ThreadPoolExecutor() as executor: + loaded_features = executor.map( + load_feature, [(self, ntype) for ntype in node_types]) + self.feature = { + node_type: feature_value for feature_value, node_type in loaded_features + } + else: + for node_type in node_types: + self.feature[node_type] = self.load(node_type) + + def share_features(self): + for node_type in self.feature: + self.feature[node_type] = self.feature[node_type].share_memory_() + + def load_from_tensor(self, node): + return torch.load(osp.join(self.path, self.dataset_size, + "processed", node, "node_feat_fp16.pt")) + + def load_in_memory_numpy(self, node): + return torch.from_numpy(np.load( + osp.join(self.path, self.dataset_size, 'processed', node, 'node_feat.npy'))) + + def load_mmap_numpy(self, node): + """ + Loads a given numpy array through mmap_mode="r" + """ + return torch.from_numpy(np.load(osp.join( + self.path, self.dataset_size, "processed", node, "node_feat.npy"), mmap_mode="r")) + + def memmap_mmap_numpy(self, node): + """ + Loads a given NumPy array through memory-mapping np.memmap. + + This is the same code as the one provided in IGB codebase. + """ + shape = [None, 1024] + if self.dataset_size == "full": + if node == "paper": + shape[0] = 269346174 + elif node == "author": + shape[0] = 277220883 + elif self.dataset_size == "large": + if node == "paper": + shape[0] = 100000000 + elif node == "author": + shape[0] = 116959896 + + assert shape[0] is not None + return torch.from_numpy(np.memmap(osp.join(self.path, self.dataset_size, + "processed", node, "node_feat.npy"), dtype="float32", mode='r', shape=shape)) + + def load(self, node): + if self.in_memory: + if self.use_fp16: + return self.load_from_tensor(node) + else: + if self.dataset_size in [ + 'large', 'full'] and node in ['paper', 'author']: + return self.memmap_mmap_numpy(node) + else: + return self.load_in_memory_numpy(node) + else: + if self.dataset_size in [ + 'large', 'full'] and node in ['paper', 'author']: + return self.memmap_mmap_numpy(node) + else: + return self.load_mmap_numpy(node) + + def get_input_features(self, input_dict, device): + # fetches the batch inputs + # moving it here so so that future modifications could be easier + return { + key: self.feature[key][value.to(torch.device("cpu")), :].to( + device).to(self.dtype) + for key, value in input_dict.items() + } diff --git a/graph/R-GAT/dgl_utilities/pyg_sampler.py b/graph/R-GAT/dgl_utilities/pyg_sampler.py new file mode 100644 index 000000000..ed75a5984 --- /dev/null +++ b/graph/R-GAT/dgl_utilities/pyg_sampler.py @@ -0,0 +1,90 @@ +import dgl +import torch + + +class PyGSampler(dgl.dataloading.Sampler): + r""" + An example DGL sampler implementation that matches PyG/GLT sampler behavior. + The following differences need to be addressed: + 1. PyG/GLT applies conv_i to edges in layer_i, and all subsequent layers, while DGL only applies conv_i to edges in layer_i. + For instance, consider a path a->b->c. At layer 0, + DGL updates only node b's embedding with a->b, but + PyG/GLT updates both node b and c's embeddings. + Therefore, if we use h_i(x) to denote the hidden representation of node x at layer i, then the output h_2(c) is: + DGL: h_2(c) = conv_2(h_1(c), h_1(b)) = conv_2(h_0(c), conv_1(h_0(b), h_0(a))) + PyG/GLT: h_2(c) = conv_2(h_1(c), h_1(b)) = conv_2(conv_1(h_0(c), h_0(b)), conv_1(h_0(b), h_0(a))) + 2. When creating blocks for layer i-1, DGL not only uses the destination nodes from layer i, + but also includes all subsequent i+1 ... n layers' destination nodes as seed nodes. + More discussions and examples can be found here: https://github.com/alibaba/graphlearn-for-pytorch/issues/79. + """ + + def __init__(self, fanouts, num_threads=1): + super().__init__() + self.fanouts = fanouts + self.num_threads = num_threads + + def sample(self, g, seed_nodes): + if self.num_threads != 1: + old_num_threads = torch.get_num_threads() + torch.set_num_threads(self.num_threads) + output_nodes = seed_nodes + subgs = [] + previous_edges = {} + previous_seed_nodes = seed_nodes + input_nodes = seed_nodes + + device = None + for key in seed_nodes: + device = seed_nodes[key].device + + not_sampled = { + ntype: torch.ones([g.num_nodes(ntype)], dtype=torch.bool, device=device) for ntype in g.ntypes + } + + for fanout in reversed(self.fanouts): + for node_type in seed_nodes: + not_sampled[node_type][seed_nodes[node_type]] = 0 + + # Sample a fixed number of neighbors of the current seed nodes. + sg = g.sample_neighbors(seed_nodes, fanout) + + # Before we add the edges, we need to first record the source nodes (of the current seed nodes) + # so that other edges' source nodes will not be included as next + # layer's seed nodes. + temp = dgl.to_block(sg, previous_seed_nodes, + include_dst_in_src=False) + seed_nodes = temp.srcdata[dgl.NID] + + # GLT/PyG does not sample again on previously-sampled nodes + # we mimic this behavior here + for node_type in g.ntypes: + seed_nodes[node_type] = seed_nodes[node_type][not_sampled[node_type] + [seed_nodes[node_type]]] + + # We add all previously accumulated edges to this subgraph + for etype in previous_edges: + sg.add_edges(*previous_edges[etype], etype=etype) + + # This subgraph now contains all its new edges + # and previously accumulated edges + # so we add them + previous_edges = {} + for etype in sg.etypes: + previous_edges[etype] = sg.edges(etype=etype) + + # Convert this subgraph to a message flow graph. + # we need to turn on the include_dst_in_src + # so that we get compatibility with DGL's OOTB GATConv. + sg = dgl.to_block(sg, previous_seed_nodes, include_dst_in_src=True) + + # for this layers seed nodes - + # they will be our next layers' destination nodes + # so we add them to the collection of previous seed nodes. + previous_seed_nodes = sg.srcdata[dgl.NID] + + # we insert the block to our list of blocks + subgs.insert(0, sg) + input_nodes = seed_nodes + if self.num_threads != 1: + torch.set_num_threads(old_num_threads) + return input_nodes, output_nodes, subgs diff --git a/upcomming_benchmarks/graph/R-GAT/igbh.py b/graph/R-GAT/igbh.py similarity index 99% rename from upcomming_benchmarks/graph/R-GAT/igbh.py rename to graph/R-GAT/igbh.py index e23a816e1..cdab173da 100644 --- a/upcomming_benchmarks/graph/R-GAT/igbh.py +++ b/graph/R-GAT/igbh.py @@ -16,8 +16,6 @@ import argparse import dataset import numpy as np -import graphlearn_torch as glt -from igb.dataloader import IGB260MDGLDataset logging.basicConfig(level=logging.INFO) diff --git a/graph/R-GAT/igbh/tiny/models/dataloader.py b/graph/R-GAT/igbh/tiny/models/dataloader.py new file mode 100644 index 000000000..cc64d1466 --- /dev/null +++ b/graph/R-GAT/igbh/tiny/models/dataloader.py @@ -0,0 +1,82 @@ +import torch +from torch_geometric.data import InMemoryDataset, Data +from dgl.data import DGLDataset + +from utils import IGL260MDataset + +# TODO: Make a PyG dataloader for large datasets + + +class IGL260M_PyG(InMemoryDataset): + def __init__(self, args): + super().__init__(root, transform, pre_transform, pre_filter) + + def process(self): + dataset = IGL260MDataset(root=self.dir, size=args.dataset_size, + in_memory=args.in_memory, classes=args.type_classes, synthetic=args.synthetic) + node_features = torch.from_numpy(dataset.paper_feat) + node_edges = torch.from_numpy(dataset.paper_edge).T + node_labels = torch.from_numpy(dataset.paper_label).to(torch.long) + data = Data(x=node_features, edge_index=node_edges, y=node_labels) + + n_nodes = node_features.shape[0] + + n_train = int(n_nodes * 0.6) + n_val = int(n_nodes * 0.2) + + train_mask = torch.zeros(n_nodes, dtype=torch.bool) + val_mask = torch.zeros(n_nodes, dtype=torch.bool) + test_mask = torch.zeros(n_nodes, dtype=torch.bool) + + train_mask[:n_train] = True + val_mask[n_train:n_train + n_val] = True + test_mask[n_train + n_val:] = True + + data.train_mask = train_mask + data.val_mask = val_mask + data.test_mask = test_mask + + +class IGL260M_DGL(DGLDataset): + def __init__(self, args): + self.dir = args.path + super().__init__(name='IGB260M') + + def process(self): + dataset = IGL260MDataset(root=self.dir, size=args.dataset_size, + in_memory=args.in_memory, classes=args.type_classes, synthetic=args.synthetic) + node_features = torch.from_numpy(dataset.paper_feat) + node_edges = torch.from_numpy(dataset.paper_edge) + node_labels = torch.from_numpy(dataset.paper_label).to(torch.long) + + self.graph = dgl.graph( + (node_edges[:, 0], node_edges[:, 1]), num_nodes=node_features.shape[0]) + + self.graph.ndata['feat'] = node_features + self.graph.ndata['label'] = node_labels + + self.graph = dgl.remove_self_loop(self.graph) + self.graph = dgl.add_self_loop(self.graph) + + n_nodes = node_features.shape[0] + + n_train = int(n_nodes * 0.6) + n_val = int(n_nodes * 0.2) + + train_mask = torch.zeros(n_nodes, dtype=torch.bool) + val_mask = torch.zeros(n_nodes, dtype=torch.bool) + test_mask = torch.zeros(n_nodes, dtype=torch.bool) + + train_mask[:n_train] = True + val_mask[n_train:n_train + n_val] = True + test_mask[n_train + n_val:] = True + + self.graph.ndata['train_mask'] = train_mask + self.graph.ndata['val_mask'] = val_mask + self.graph.ndata['test_mask'] = test_mask + + def __getitem__(self, i): + return self.graph + + def __len__(self): + return 1 diff --git a/graph/R-GAT/igbh/tiny/models/gnn.py b/graph/R-GAT/igbh/tiny/models/gnn.py new file mode 100644 index 000000000..20d5ecd72 --- /dev/null +++ b/graph/R-GAT/igbh/tiny/models/gnn.py @@ -0,0 +1,296 @@ +from utils import IGL260MDataset +import warnings +from tqdm import tqdm +import numpy as np +import time +import torch.nn.functional as F +import torch.optim as optim +import torch.nn as nn +import dgl +from dgl.data import DGLDataset +import dgl.nn.pytorch as dglnn +from dgl.nn.pytorch import GATConv, GraphConv, SAGEConv +import os.path as osp +from sys import getsizeof + + +import torch +torch.manual_seed(0) +dgl.seed(0) +warnings.filterwarnings("ignore") + + +class GCN(nn.Module): + def __init__(self, + in_feats, + n_hidden, + n_classes, + n_layers, + activation, + dropout): + super(GCN, self).__init__() + self.layers = nn.ModuleList() + self.n_layers = n_layers + self.n_hidden = n_hidden + self.n_classes = n_classes + # input layer + self.layers.append( + GraphConv( + in_feats, + n_hidden, + activation=activation)) + # hidden layers + for i in range(n_layers - 1): + self.layers.append( + GraphConv( + n_hidden, + n_hidden, + activation=activation)) + # output layer + self.layers.append(GraphConv(n_hidden, n_classes)) + self.dropout = nn.Dropout(p=dropout) + self.activation = activation + + def forward(self, blocks, x): + h = x + for l, (layer, block) in enumerate(zip(self.layers, blocks)): + if l != len(self.layers) - 1: + # h = self.activation(h) + h = self.dropout(h) + h = layer(block, h) + return h + + def inference(self, g, x, batch_size, device): + """ + Inference with the GraphSAGE model on full neighbors (i.e. without neighbor sampling). + g : the entire graph. + x : the input of entire node set. + The inference code is written in a fashion that it could handle any number of nodes and + layers. + """ + # During inference with sampling, multi-layer blocks are very inefficient because + # lots of computations in the first few layers are repeated. + # Therefore, we compute the representation of all nodes layer by layer. The nodes + # on each layer are of course splitted in batches. + # TODO: can we standardize this? + for l, layer in enumerate(self.layers): + y = torch.zeros(g.number_of_nodes(), self.n_hidden if l != + len(self.layers) - 1 else self.n_classes) + + sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1) + dataloader = dgl.dataloading.NodeDataLoader( + g, + torch.arange(g.number_of_nodes()), + sampler, + batch_size=batch_size, + shuffle=True, + drop_last=False, + num_workers=4) + + for input_nodes, output_nodes, blocks in dataloader: + block = blocks[0] + + block = block.int().to(device) + h = x[input_nodes].to(device) + h = layer(block, h) + if l != len(self.layers) - 1: + h = self.activation(h) + h = self.dropout(h) + + y[output_nodes] = h.cpu() + + x = y + return y + + +class GAT(nn.Module): + def __init__( + self, in_feats, n_hidden, n_classes, n_layers, num_heads, activation + ): + super().__init__() + self.n_layers = n_layers + self.n_hidden = n_hidden + self.n_classes = n_classes + self.layers = nn.ModuleList() + self.layers.append( + dglnn.GATConv( + (in_feats, in_feats), + n_hidden, + num_heads=num_heads, + activation=activation, + ) + ) + for i in range(1, n_layers - 1): + self.layers.append( + dglnn.GATConv( + (n_hidden * num_heads, n_hidden * num_heads), + n_hidden, + num_heads=num_heads, + activation=activation, + ) + ) + self.layers.append( + dglnn.GATConv( + (n_hidden * num_heads, n_hidden * num_heads), + n_classes, + num_heads=num_heads, + activation=None, + ) + ) + + def forward(self, blocks, x): + h = x + for l, (layer, block) in enumerate(zip(self.layers, blocks)): + # We need to first copy the representation of nodes on the RHS from the + # appropriate nodes on the LHS. + # Note that the shape of h is (num_nodes_LHS, D) and the shape of h_dst + # would be (num_nodes_RHS, D) + h_dst = h[: block.num_dst_nodes()] + # Then we compute the updated representation on the RHS. + # The shape of h now becomes (num_nodes_RHS, D) + if l < self.n_layers - 1: + h = layer(block, (h, h_dst)).flatten(1) + else: + h = layer(block, (h, h_dst)) + h = h.mean(1) + return h.log_softmax(dim=-1) + + def inference(self, g, x, batch_size, device): + """ + Inference with the GAT model on full neighbors (i.e. without neighbor sampling). + g : the entire graph. + x : the input of entire node set. + The inference code is written in a fashion that it could handle any number of nodes and + layers. + """ + # During inference with sampling, multi-layer blocks are very inefficient because + # lots of computations in the first few layers are repeated. + # Therefore, we compute the representation of all nodes layer by layer. The nodes + # on each layer are of course splitted in batches. + # TODO: can we standardize this? + # TODO: make thiw into a variable + num_heads = 2 + for l, layer in enumerate(self.layers): + if l < self.n_layers - 1: + y = torch.zeros( + g.num_nodes(), + self.n_hidden * num_heads + if l != len(self.layers) - 1 + else self.n_classes, + ) + else: + y = torch.zeros( + g.num_nodes(), + self.n_hidden + if l != len(self.layers) - 1 + else self.n_classes, + ) + + sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1) + dataloader = dgl.dataloading.DataLoader( + g, + torch.arange(g.num_nodes()), + sampler, + batch_size=batch_size, + shuffle=True, + drop_last=False, + num_workers=4, + ) + + for input_nodes, output_nodes, blocks in dataloader: + block = blocks[0].int().to(device) + + h = x[input_nodes].to(device) + h_dst = h[: block.num_dst_nodes()] + if l < self.n_layers - 1: + h = layer(block, (h, h_dst)).flatten(1) + else: + h = layer(block, (h, h_dst)) + h = h.mean(1) + h = h.log_softmax(dim=-1) + + y[output_nodes] = h.cpu() + + x = y + return y + + +class SAGE(nn.Module): + def __init__(self, + in_feats, + n_hidden, + n_classes, + n_layers, + activation, + dropout, + aggregator_type): + super().__init__() + self.n_layers = n_layers + self.n_hidden = n_hidden + self.n_classes = n_classes + self.layers = nn.ModuleList() + self.layers.append(dglnn.SAGEConv(in_feats, n_hidden, aggregator_type)) + for i in range(1, n_layers - 1): + self.layers.append( + dglnn.SAGEConv( + n_hidden, + n_hidden, + aggregator_type)) + self.layers.append( + dglnn.SAGEConv( + n_hidden, + n_classes, + aggregator_type)) + self.dropout = nn.Dropout(dropout) + self.activation = activation + + def forward(self, blocks, x): + h = x + for l, (layer, block) in enumerate(zip(self.layers, blocks)): + h = layer(block, h) + if l != len(self.layers) - 1: + h = self.activation(h) + h = self.dropout(h) + return h + + def inference(self, g, x, batch_size, device): + """ + Inference with the GraphSAGE model on full neighbors (i.e. without neighbor sampling). + g : the entire graph. + x : the input of entire node set. + The inference code is written in a fashion that it could handle any number of nodes and + layers. + """ + # During inference with sampling, multi-layer blocks are very inefficient because + # lots of computations in the first few layers are repeated. + # Therefore, we compute the representation of all nodes layer by layer. The nodes + # on each layer are of course splitted in batches. + # TODO: can we standardize this? + for l, layer in enumerate(self.layers): + y = torch.zeros(g.number_of_nodes(), self.n_hidden if l != + len(self.layers) - 1 else self.n_classes) + + sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1) + dataloader = dgl.dataloading.NodeDataLoader( + g, + torch.arange(g.number_of_nodes()), + sampler, + batch_size=batch_size, + shuffle=True, + drop_last=False, + num_workers=4) + + for input_nodes, output_nodes, blocks in dataloader: + block = blocks[0] + + block = block.int().to(device) + h = x[input_nodes].to(device) + h = layer(block, h) + if l != len(self.layers) - 1: + h = self.activation(h) + h = self.dropout(h) + + y[output_nodes] = h.cpu() + + x = y + return y diff --git a/graph/R-GAT/igbh/tiny/models/main.py b/graph/R-GAT/igbh/tiny/models/main.py new file mode 100644 index 000000000..4ab22eb75 --- /dev/null +++ b/graph/R-GAT/igbh/tiny/models/main.py @@ -0,0 +1,79 @@ +import argparse + + +def main(): + parser = argparse.ArgumentParser() + + # Input/output paths + parser.add_argument('--path', type=str, default='/gnndataset/') + parser.add_argument('--modelpath', type=str, default='gcn_19.pt') + + # Dataset selection + parser.add_argument( + '--dataset_size', + type=str, + default='experimental', + choices=[ + 'experimental', + 'small', + 'medium', + 'large', + 'full']) + parser.add_argument( + '--type_classes', + type=int, + default=19, + choices=[ + 19, + 292, + 2983]) + + # Hyperparameters + parser.add_argument('--hidden_channels', type=int, default=16) + parser.add_argument('--fan_out', type=str, default='5,10') + parser.add_argument('--num_layers', type=int, default=2) + parser.add_argument('--learning_rate', type=int, default=0.01) + parser.add_argument('--decay', type=int, default=0.001) + parser.add_argument('--num_workers', type=int, default=4) + parser.add_argument('--batch_size', type=int, default=2048 * 16) + parser.add_argument('--dropout', type=float, default=0.2) + parser.add_argument('--epochs', type=int, default=20) + parser.add_argument( + '--model_type', + type=str, + default='gcn', + choices=[ + 'gat', + 'sage', + 'gcn']) + parser.add_argument('--in_memory', type=int, default=0) + parser.add_argument('--synthetic', type=int, default=0) + parser.add_argument('--device', type=str, default='1') + args = parser.parse_args() + + print("Dataset_size: " + args.dataset_size) + print("Model : " + args.model) + print("Num_classes : " + str(args.num_classes)) + print() + + device = f'cuda:' + args.device if torch.cuda.is_available() else 'cpu' + + dataset = IGL260M_DGL(args) + g = dataset[0] + + best_test_acc, train_acc, test_acc = track_acc(g, args) + + print( + f"Train accuracy: {np.mean(train_acc):.2f} \u00B1 {np.std(train_acc):.2f} \t Best: {np.max(train_acc) * 100:.4f}%") + print( + f"Test accuracy: {np.mean(test_acc):.2f} \u00B1 {np.std(test_acc):.2f} \t Best: {np.max(test_acc) * 100:.4f}%") + print() + print(" -------- For debugging --------- ") + print("Parameters: ", args) + print(g) + print("Train accuracy: ", train_acc) + print("Test accuracy: ", test_acc) + + +if __name__ == '__main__': + main() diff --git a/graph/R-GAT/igbh/tiny/models/utils.py b/graph/R-GAT/igbh/tiny/models/utils.py new file mode 100644 index 000000000..5e9e1a25d --- /dev/null +++ b/graph/R-GAT/igbh/tiny/models/utils.py @@ -0,0 +1,224 @@ +import numpy as np +import torch + + +class IGL260MDataset(object): + def __init__(self, root: str, size: str, in_memory: int, + classes: int, synthetic: int): + self.dir = root + self.size = size + self.synthetic = synthetic + self.in_memory = in_memory + self.num_classes = classes + self.__meta__ = torch.load(osp.join(self.dir, self.size, 'meta.pt')) + + self.num_features = self.__meta__['paper']['emb_dim'] + self.num_nodes = self.__meta__['paper']['num_node'] + self.num_edges = self.__meta__['cites']['num_edge'] + + @property + def paper_feat(self) -> np.ndarray: + if self.synthetic: + return np.random((self.num_nodes, self.num_edges)) + + path = osp.join( + self.dir, + self.size, + 'processed', + 'paper', + 'node_feat.npy') + if self.in_memory: + return np.load(path) + else: + return np.load(path, mmap_mode='r') + + @property + def paper_label(self) -> np.ndarray: + if self.num_classes == 19: + path = osp.join( + self.dir, + self.size, + 'processed', + 'paper', + 'node_label_19.npy') + else: + path = osp.join( + self.dir, + self.size, + 'processed', + 'paper', + 'node_label_2K.npy') + if self.in_memory: + return np.load(path) + else: + return np.load(path, mmap_mode='r') + + @property + def paper_edge(self) -> np.ndarray: + path = osp.join( + self.dir, + self.size, + 'processed', + 'paper__cites__paper', + 'edge_index.npy') + if self.in_memory: + return np.load(path) + else: + return np.load(path, mmap_mode='r') + + +def compute_acc(pred, labels): + """ + Compute the accuracy of prediction given the labels. + """ + labels = labels.long() + return (torch.argmax(pred, dim=1) == labels).float().sum() / len(pred) + + +def evaluate(model, g, inputs, labels, val_nid, batch_size, device): + """ + Evaluate the model on the validation set specified by ``val_nid``. + g : The entire graph. + inputs : The features of all the nodes. + labels : The labels of all the nodes. + val_nid : the node Ids for validation. + batch_size : Number of nodes to compute at the same time. + device : The GPU device to evaluate on. + """ + model.eval() + with torch.no_grad(): + pred = model.inference(g, inputs, batch_size, device) + model.train() + return compute_acc(pred[val_nid], labels[val_nid]) + + +def load_subtensor(g, seeds, input_nodes, device): + """ + Copys features and labels of a set of nodes onto GPU. + """ + batch_inputs = g.ndata['features'][input_nodes].to(device) + batch_labels = g.ndata['labels'][seeds].to(device) + return batch_inputs, batch_labels + + +def track_acc(g, args): + train_accuracy = [] + test_accuracy = [] + g.ndata['features'] = g.ndata['feat'] + g.ndata['labels'] = g.ndata['label'] + in_feats = g.ndata['features'].shape[1] + n_classes = args.num_classes + + # Create csr/coo/csc formats before launching training processes with multi-gpu. + # This avoids creating certain formats in each sub-process, which saves + # momory and CPU. + g.create_formats_() + + num_epochs = args.epochs + num_hidden = args.hidden_channels + num_layers = args.num_layers + fan_out = args.fan_out + batch_size = args.batch_size + lr = args.learning_rate + dropout = args.dropout + num_workers = args.num_workers + + train_nid = torch.nonzero(g.ndata['train_mask'], as_tuple=True)[0] + + # Create PyTorch DataLoader for constructing blocks + sampler = dgl.dataloading.MultiLayerNeighborSampler( + [int(fanout) for fanout in fan_out.split(',')]) + + dataloader = dgl.dataloading.NodeDataLoader( + g, + train_nid, + sampler, + batch_size=batch_size, + shuffle=True, + drop_last=False, + num_workers=num_workers) + + # Define model and optimizer + if args.model_type == 'gcn': + model = GCN(in_feats, num_hidden, n_classes, 1, F.relu, dropout) + if args.model_type == 'sage': + model = SAGE( + in_feats, + num_hidden, + n_classes, + num_layers, + F.relu, + dropout, + 'gcn') + if args.model_type == 'gat': + model = GAT(in_feats, num_hidden, n_classes, num_layers, 2, F.relu) + + model = model.to(device) + loss_fcn = nn.CrossEntropyLoss() + loss_fcn = loss_fcn.to(device) + optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=args.decay) + + # Training loop + avg = 0 + best_test_acc = 0 + log_every = 1 + training_start = time.time() + for epoch in (range(num_epochs)): + # Loop over the dataloader to sample the computation dependency graph as a list of + # blocks. + epoch_loss = 0 + gpu_mem_alloc = 0 + epoch_start = time.time() + for step, (input_nodes, seeds, blocks) in (enumerate(dataloader)): + # Load the input features as well as output labels + # batch_inputs, batch_labels = load_subtensor(g, seeds, input_nodes, device) + blocks = [block.int().to(device) for block in blocks] + batch_inputs = blocks[0].srcdata['features'] + batch_labels = blocks[-1].dstdata['labels'] + + # Compute loss and prediction + batch_pred = model(blocks, batch_inputs) + loss = loss_fcn(batch_pred, batch_labels) + optimizer.zero_grad() + loss.backward() + optimizer.step() + + epoch_loss += loss.detach() + + gpu_mem_alloc += ( + torch.cuda.max_memory_allocated() / 1000000 + if torch.cuda.is_available() + else 0 + ) + + train_g = g + train_nid = torch.nonzero( + train_g.ndata['train_mask'], as_tuple=True)[0] + train_acc = evaluate( + model, train_g, train_g.ndata['features'], train_g.ndata['labels'], train_nid, batch_size, device) + + test_g = g + test_nid = torch.nonzero( + test_g.ndata['test_mask'], as_tuple=True)[0] + test_acc = evaluate( + model, test_g, test_g.ndata['features'], test_g.ndata['labels'], test_nid, batch_size, device) + + if test_acc.item() > best_test_acc: + best_test_acc = test_acc.item() + tqdm.write( + "Epoch {:05d} | Loss {:.4f} | Train Acc {:.4f} | Test Acc {:.4f} | Time {:.2f}s | GPU {:.1f} MB".format( + epoch, + epoch_loss, + train_acc.item(), + test_acc.item(), + time.time() - epoch_start, + gpu_mem_alloc + ) + ) + test_accuracy.append(test_acc.item()) + train_accuracy.append(train_acc.item()) + torch.save(model.state_dict(), args.modelpath) + print() + print("Total time taken: ", time.time() - training_start) + + return best_test_acc, train_accuracy, test_accuracy diff --git a/upcomming_benchmarks/graph/R-GAT/main.py b/graph/R-GAT/main.py similarity index 86% rename from upcomming_benchmarks/graph/R-GAT/main.py rename to graph/R-GAT/main.py index a7697481a..d76ea090d 100644 --- a/upcomming_benchmarks/graph/R-GAT/main.py +++ b/graph/R-GAT/main.py @@ -23,6 +23,8 @@ import dataset import igbh +import dgl_utilities.feature_fetching as dgl_igbh + logging.basicConfig(level=logging.INFO) log = logging.getLogger("main") @@ -31,68 +33,123 @@ MILLI_SEC = 1000 SUPPORTED_DATASETS = { - "igbh-tiny": ( + "igbh-glt-tiny": ( igbh.IGBH, dataset.preprocess, igbh.PostProcessIGBH(), {"dataset_size": "tiny", "use_label_2K": True}, ), - "igbh-small": ( + "igbh-glt-small": ( igbh.IGBH, dataset.preprocess, igbh.PostProcessIGBH(), {"dataset_size": "small", "use_label_2K": True}, ), - "igbh-medium": ( + "igbh-glt-medium": ( igbh.IGBH, dataset.preprocess, igbh.PostProcessIGBH(), {"dataset_size": "medium", "use_label_2K": True}, ), - "igbh-large": ( + "igbh-glt-large": ( igbh.IGBH, dataset.preprocess, igbh.PostProcessIGBH(), {"dataset_size": "large", "use_label_2K": True}, ), - "igbh": ( + "igbh-glt": ( igbh.IGBH, dataset.preprocess, igbh.PostProcessIGBH(), {"dataset_size": "full", "use_label_2K": True}, ), + "igbh-dgl-tiny": ( + dgl_igbh.IGBH, + dataset.preprocess, + igbh.PostProcessIGBH(), + {"dataset_size": "tiny", "use_label_2K": True}, + ), + "igbh-dgl-small": ( + dgl_igbh.IGBH, + dataset.preprocess, + igbh.PostProcessIGBH(), + {"dataset_size": "small", "use_label_2K": True}, + ), + "igbh-dgl-medium": ( + dgl_igbh.IGBH, + dataset.preprocess, + igbh.PostProcessIGBH(), + {"dataset_size": "medium", "use_label_2K": True}, + ), + "igbh-dgl-large": ( + dgl_igbh.IGBH, + dataset.preprocess, + igbh.PostProcessIGBH(), + {"dataset_size": "large", "use_label_2K": True}, + ), + "igbh-dgl": ( + dgl_igbh.IGBH, + dataset.preprocess, + igbh.PostProcessIGBH(), + {"dataset_size": "full", "use_label_2K": True}, + ), } SUPPORTED_PROFILES = { "defaults": { - "dataset": "igbh-tiny", - "backend": "pytorch", + "dataset": "igbh-glt-tiny", + "backend": "glt", + "model-name": "rgat", + }, + "debug-glt": { + "dataset": "igbh-glt-tiny", + "backend": "glt", + "model-name": "rgat", + }, + "rgat-glt-small": { + "dataset": "igbh-glt-small", + "backend": "glt", + "model-name": "rgat", + }, + "rgat-glt-medium": { + "dataset": "igbh-glt-medium", + "backend": "glt", + "model-name": "rgat", + }, + "rgat-glt-large": { + "dataset": "igbh-glt-large", + "backend": "glt", + "model-name": "rgat", + }, + "rgat-glt-full": { + "dataset": "igbh-glt", + "backend": "glt", "model-name": "rgat", }, - "debug": { - "dataset": "igbh-tiny", - "backend": "pytorch", + "debug-dgl": { + "dataset": "igbh-dgl-tiny", + "backend": "dgl", "model-name": "rgat", }, - "rgat-pytorch-small": { - "dataset": "igbh-small", - "backend": "pytorch", + "rgat-dgl-small": { + "dataset": "igbh-dgl-small", + "backend": "dgl", "model-name": "rgat", }, - "rgat-pytorch-medium": { - "dataset": "igbh-medium", - "backend": "pytorch", + "rgat-dgl-medium": { + "dataset": "igbh-dgl-medium", + "backend": "dgl", "model-name": "rgat", }, - "rgat-pytorch-large": { - "dataset": "igbh-large", - "backend": "pytorch", + "rgat-dgl-large": { + "dataset": "igbh-dgl-large", + "backend": "dgl", "model-name": "rgat", }, - "rgat-pytorch-full": { - "dataset": "igbh", - "backend": "pytorch", + "rgat-dgl-full": { + "dataset": "igbh-dgl", + "backend": "dgl", "model-name": "rgat", }, } @@ -226,10 +283,12 @@ def get_args(): def get_backend(backend, **kwargs): - if backend == "pytorch": - from backend_pytorch import BackendPytorch - - backend = BackendPytorch(**kwargs) + if backend == "glt": + from backend_glt import BackendGLT + backend = BackendGLT(**kwargs) + elif backend == "dgl": + from backend_dgl import BackendDGL + backend = BackendDGL(**kwargs) else: raise ValueError("unknown backend: " + backend) return backend @@ -380,7 +439,7 @@ def main(): device=args.device, ckpt_path=args.model_path, batch_size=args.max_batchsize, - igbh_dataset=ds.igbh_dataset, + igbh=ds, layout=args.layout, ) @@ -425,7 +484,7 @@ def main(): count = ds.get_item_count() # warmup - warmup_samples = torch.Tensor([0]).to(torch.int64).to(backend.device) + warmup_samples = torch.Tensor([0]).to(torch.int64) for i in range(5): _ = backend.predict(warmup_samples) diff --git a/upcomming_benchmarks/graph/R-GAT/requirements.txt b/graph/R-GAT/requirements.txt similarity index 90% rename from upcomming_benchmarks/graph/R-GAT/requirements.txt rename to graph/R-GAT/requirements.txt index 062b710cd..8e6e7276f 100644 --- a/upcomming_benchmarks/graph/R-GAT/requirements.txt +++ b/graph/R-GAT/requirements.txt @@ -2,7 +2,7 @@ colorama==0.4.6 tqdm==4.66.4 requests==2.32.2 torch==2.1.0 -dgl==2.1.0 +torchdata==0.7.0 pybind11==2.12.0 PyYAML==6.0.1 pydantic==2.7.1 diff --git a/upcomming_benchmarks/graph/R-GAT/rgnn.py b/graph/R-GAT/rgnn.py similarity index 100% rename from upcomming_benchmarks/graph/R-GAT/rgnn.py rename to graph/R-GAT/rgnn.py diff --git a/upcomming_benchmarks/graph/R-GAT/tools/compress_graph.py b/graph/R-GAT/tools/compress_graph.py similarity index 100% rename from upcomming_benchmarks/graph/R-GAT/tools/compress_graph.py rename to graph/R-GAT/tools/compress_graph.py diff --git a/upcomming_benchmarks/graph/R-GAT/tools/download_igbh_full.sh b/graph/R-GAT/tools/download_igbh_full.sh similarity index 100% rename from upcomming_benchmarks/graph/R-GAT/tools/download_igbh_full.sh rename to graph/R-GAT/tools/download_igbh_full.sh diff --git a/upcomming_benchmarks/graph/R-GAT/tools/download_igbh_test.py b/graph/R-GAT/tools/download_igbh_test.py similarity index 100% rename from upcomming_benchmarks/graph/R-GAT/tools/download_igbh_test.py rename to graph/R-GAT/tools/download_igbh_test.py diff --git a/upcomming_benchmarks/graph/R-GAT/tools/format_model.py b/graph/R-GAT/tools/format_model.py similarity index 100% rename from upcomming_benchmarks/graph/R-GAT/tools/format_model.py rename to graph/R-GAT/tools/format_model.py diff --git a/upcomming_benchmarks/graph/R-GAT/tools/split_seeds.py b/graph/R-GAT/tools/split_seeds.py similarity index 100% rename from upcomming_benchmarks/graph/R-GAT/tools/split_seeds.py rename to graph/R-GAT/tools/split_seeds.py diff --git a/upcomming_benchmarks/graph/R-GAT/user.conf b/graph/R-GAT/user.conf similarity index 100% rename from upcomming_benchmarks/graph/R-GAT/user.conf rename to graph/R-GAT/user.conf