diff --git a/CHANGELOG.md b/CHANGELOG.md
index dd0833194744..aa2dbefb272c 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Added
 
+- Added Kùzu remote backend examples ([#7298](https://github.com/pyg-team/pytorch_geometric/pull/7298))
 - Fixed tracing of `add_self_loops` for a dynamic number of nodes ([#7330](https://github.com/pyg-team/pytorch_geometric/pull/7330))
 - Added an optional `add_pad_mask` argument to the `Pad` transform ([#7339](https://github.com/pyg-team/pytorch_geometric/pull/7339))
 - Added `keep_inter_cluster_edges` option to `ClusterData` to support inter-subgraph edge connections when doing graph partitioning ([#7326](https://github.com/pyg-team/pytorch_geometric/pull/7326))
diff --git a/docs/source/advanced/remote.rst b/docs/source/advanced/remote.rst
index bb04394d7107..095ef463e2bb 100644
--- a/docs/source/advanced/remote.rst
+++ b/docs/source/advanced/remote.rst
@@ -107,7 +107,8 @@ An example usage of the interface is shown below:
     assert torch.equal(row, edge_index[0])
     assert torch.equal(col, edge_index[1])
 
-Common implementations of the :class:`~torch_geometric.data.GraphStore` are graph databases, *e.g.*, :obj:`Neo4j`, :obj:`TigerGraph`, :obj:`ArangoDB` are all viable performant options.
+Common implementations of the :class:`~torch_geometric.data.GraphStore` are graph databases, *e.g.*, :obj:`Neo4j`, :obj:`TigerGraph`, :obj:`ArangoDB` and :obj:`Kùzu` are all viable and performant options.
+We provide an example of using :pyg:`PyG` in combination with the :obj:`Kùzu` database `here <https://github.com/pyg-team/pytorch_geometric/tree/master/examples/kuzu>`__.
 
 A graph sampler is tightly coupled to the given :class:`~torch_geometric.data.GraphStore`, and operates on the :class:`~torch_geometric.data.GraphStore` to produce sampled subgraphs from input nodes.
 Different sampling algorithms are implemented behind the :class:`torch_geometric.sampler.BaseSampler` interface.
diff --git a/docs/source/external/resources.rst b/docs/source/external/resources.rst
index bd437e376549..cd62497f9436 100644
--- a/docs/source/external/resources.rst
+++ b/docs/source/external/resources.rst
@@ -38,3 +38,5 @@ External Resources
 * Amitoz Azad: **Primal-Dual Algorithm for Total Variation Processing on Graphs** [`Jupyter `__]
 
 * Manan Goel: **Recommending Amazon Products using Graph Neural Networks in** :pyg:`null` **PyTorch Geometric** [:wandb:`null` `W&B Report `__]
+
+* Kùzu: **Remote Backend for** :pyg:`null` **PyTorch Geometric** [:colab:`null` `Colab <https://colab.research.google.com/drive/12fOSqPm1HQTz_m9caRW7E_92vaeD9xq6>`__]
diff --git a/examples/kuzu/README.md b/examples/kuzu/README.md
new file mode 100644
index 000000000000..298baf8f9493
--- /dev/null
+++ b/examples/kuzu/README.md
@@ -0,0 +1,38 @@
+# Using Kùzu as a Remote Backend for PyG
+
+[Kùzu](https://kuzudb.com/) is an in-process property graph database management system built for query speed and scalability.
+It provides integration with PyG via PyG's [remote backend interface](https://pytorch-geometric.readthedocs.io/en/latest/advanced/remote.html).
+The Python API of Kùzu outputs a [`torch_geometric.data.FeatureStore`](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.data.FeatureStore.html) and a [`torch_geometric.data.GraphStore`](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.data.GraphStore.html) that can be plugged directly into familiar PyG interfaces such as [`NeighborLoader`](https://pytorch-geometric.readthedocs.io/en/latest/_modules/torch_geometric/loader/neighbor_loader.html), which enables training GNNs directly on graphs stored in Kùzu.
+This is particularly useful if you would like to train on graphs that do not fit into your machine's main memory.
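+
+Concretely, the two stores returned by Kùzu can be passed straight into a `NeighborLoader`. The following is a minimal sketch: it assumes an existing Kùzu database under `papers100M` with a `paper` node table and a `cites` relationship table (as created by the `papers_100M` example below), and uses a placeholder tensor of seed nodes:
+
+```python
+from multiprocessing import cpu_count
+
+import kuzu
+import torch
+from torch_geometric.loader import NeighborLoader
+
+db = kuzu.Database('papers100M')
+
+# Kùzu exposes the database as a PyG feature store and graph store:
+feature_store, graph_store = db.get_torch_geometric_remote_backend(
+    cpu_count())
+
+loader = NeighborLoader(
+    data=(feature_store, graph_store),
+    num_neighbors={('paper', 'cites', 'paper'): [12, 12]},
+    batch_size=1024,
+    input_nodes=('paper', torch.arange(1024)),  # Placeholder seed nodes.
+    filter_per_worker=False,  # The Kùzu database object is not fork-safe.
+)
+```
+
+Mini-batches are then fetched from the database on the fly, so the full graph never has to be materialized in memory at once.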
+
+## Installation
+
+You can install Kùzu as follows:
+
+```bash
+pip install kuzu
+```
+
+## Usage
+
+The API and design documentation of Kùzu can be found at [https://kuzudb.com/docs/](https://kuzudb.com/docs/).
+
+## Examples
+
+We provide the following examples to showcase the usage of the Kùzu remote backend within PyG:
+
+### PubMed
+
+[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/12fOSqPm1HQTz_m9caRW7E_92vaeD9xq6)
+
+The PubMed example is hosted on [Google Colab](https://colab.research.google.com/drive/12fOSqPm1HQTz_m9caRW7E_92vaeD9xq6).
+In this example, we work on a small dataset for demonstrative purposes.
+The [PubMed](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.datasets.Planetoid.html) dataset consists of 19,717 papers as nodes and 88,648 citation relationships between them.
+
+### `papers_100M`
+
+This example shows how to use the remote backend feature of Kùzu to work with a large graph of papers and citations on a single machine.
+The data used in this example is `ogbn-papers100M` from the [Open Graph Benchmark](https://ogb.stanford.edu/).
+The dataset contains approximately 111 million nodes and 1.6 billion edges.
diff --git a/examples/kuzu/papers_100M/README.md b/examples/kuzu/papers_100M/README.md
new file mode 100644
index 000000000000..c23bc2a972f8
--- /dev/null
+++ b/examples/kuzu/papers_100M/README.md
@@ -0,0 +1,16 @@
+# `papers_100M` Example
+
+This example shows how to use the remote backend feature of [Kùzu](https://kuzudb.com) to work with a large graph of papers and citations on a single machine.
+The data used in this example is `ogbn-papers100M` from the [Open Graph Benchmark](https://ogb.stanford.edu/).
+The dataset contains approximately 111 million nodes and 1.6 billion edges.
+
+## Prepare the data
+
+1. Download the dataset from [`http://snap.stanford.edu/ogb/data/nodeproppred/papers100M-bin.zip`](http://snap.stanford.edu/ogb/data/nodeproppred/papers100M-bin.zip) and put the `*.zip` file into this directory.
+2. Run `python prepare_data.py`.
+   The script will automatically extract the data and convert it into a format that Kùzu can read.
+   A Kùzu database instance is then created under `papers100M` and the data is loaded into it (see the sanity-check sketch below).
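+
+Once `prepare_data.py` has finished, you can sanity-check the database with a small Cypher query. This is a minimal sketch; it assumes Kùzu's `get_as_df()` result helper, which requires `pandas`:
+
+```python
+import kuzu
+
+# Open the database created by `prepare_data.py`:
+db = kuzu.Database('papers100M')
+conn = kuzu.Connection(db)
+
+# Count the loaded paper nodes; roughly 111 million rows are expected:
+result = conn.execute('MATCH (p:paper) RETURN COUNT(*);')
+print(result.get_as_df())
+```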
+
+## Train a Model
+
+Afterwards, run `python train.py` to train a three-layer [`GraphSAGE`](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.models.GraphSAGE.html) model on this dataset.
diff --git a/examples/kuzu/papers_100M/prepare_data.py b/examples/kuzu/papers_100M/prepare_data.py
new file mode 100644
index 000000000000..a4892a6df895
--- /dev/null
+++ b/examples/kuzu/papers_100M/prepare_data.py
@@ -0,0 +1,54 @@
+from multiprocessing import cpu_count
+from os import path
+from zipfile import ZipFile
+
+import kuzu
+import numpy as np
+from tqdm import tqdm
+
+# Extract the downloaded archive and the nested `*.npz` files:
+with ZipFile("papers100M-bin.zip", 'r') as papers100M_zip:
+    print('Extracting papers100M-bin.zip...')
+    papers100M_zip.extractall()
+
+with ZipFile("papers100M-bin/raw/data.npz", 'r') as data_zip:
+    print('Extracting data.npz...')
+    data_zip.extractall()
+
+with ZipFile("papers100M-bin/raw/node-label.npz", 'r') as node_label_zip:
+    print('Extracting node-label.npz...')
+    node_label_zip.extractall()
+
+# Convert the edge index to a CSV file that Kùzu can bulk-load:
+print("Converting edge_index to CSV...")
+edge_index = np.load('edge_index.npy', mmap_mode='r')
+with open('edge_index.csv', 'w') as csvfile:
+    csvfile.write('src,dst\n')
+    for i in tqdm(range(edge_index.shape[1])):
+        csvfile.write(f'{edge_index[0, i]},{edge_index[1, i]}\n')
+
+# Assign consecutive IDs to the nodes, which serve as primary keys:
+print("Generating IDs for nodes...")
+node_year = np.load('node_year.npy', mmap_mode='r')
+length = node_year.shape[0]
+ids = np.arange(length)
+np.save('ids.npy', ids)
+
+ids_path = path.abspath(path.join('.', 'ids.npy'))
+edge_index_path = path.abspath(path.join('.', 'edge_index.csv'))
+node_label_path = path.abspath(path.join('.', 'node_label.npy'))
+node_feature_path = path.abspath(path.join('.', 'node_feat.npy'))
+node_year_path = path.abspath(path.join('.', 'node_year.npy'))
+
+print("Creating Kùzu database...")
+db = kuzu.Database('papers100M')
+conn = kuzu.Connection(db, num_threads=cpu_count())
+print("Creating Kùzu tables...")
+conn.execute(
+    "CREATE NODE TABLE paper(id INT64, x FLOAT[128], year INT64, y FLOAT, "
+    "PRIMARY KEY (id));")
+conn.execute("CREATE REL TABLE cites(FROM paper TO paper, MANY_MANY);")
+print("Copying nodes to Kùzu tables...")
+# `COPY ... BY COLUMN` bulk-loads each node property from its own file:
+conn.execute('COPY paper FROM ("%s", "%s", "%s", "%s") BY COLUMN;' %
+             (ids_path, node_feature_path, node_year_path, node_label_path))
+print("Copying edges to Kùzu tables...")
+conn.execute('COPY cites FROM "%s";' % (edge_index_path))
+print("All done!")
diff --git a/examples/kuzu/papers_100M/train.py b/examples/kuzu/papers_100M/train.py
new file mode 100644
index 000000000000..5b3da061eb79
--- /dev/null
+++ b/examples/kuzu/papers_100M/train.py
@@ -0,0 +1,119 @@
+import multiprocessing as mp
+import os.path as osp
+
+import kuzu
+import pandas as pd
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from tqdm import tqdm
+
+from torch_geometric.loader import NeighborLoader
+from torch_geometric.nn import MLP, BatchNorm, SAGEConv
+
+NUM_EPOCHS = 1
+LOADER_BATCH_SIZE = 1024
+
+print('Batch size:', LOADER_BATCH_SIZE)
+print('Number of epochs:', NUM_EPOCHS)
+
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+print('Using device:', device)
+
+# Load the train set:
+train_path = osp.join('.', 'papers100M-bin', 'split', 'time', 'train.csv.gz')
+train_df = pd.read_csv(
+    osp.abspath(train_path),
+    compression='gzip',
+    header=None,
+)
+input_nodes = torch.tensor(train_df[0].values, dtype=torch.long)
+
+########################################################################
+# The below code sets up the remote backend of Kùzu for PyG.
+# Please refer to: https://kuzudb.com/docs/client-apis/python-api/overview.html
+# for how to use the Python API of Kùzu.
+########################################################################
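+
+# As a quick orientation, the core Kùzu Python API follows this pattern (a
+# sketch; see the documentation linked above for details):
+#
+#   db = kuzu.Database('papers100M')   # Open (or create) a database.
+#   conn = kuzu.Connection(db)         # Connect to it.
+#   conn.execute('MATCH (p:paper) RETURN COUNT(*);')  # Run a Cypher query.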
+
+# The buffer pool size of Kùzu is set to 40 GB. You can change it to a
+# smaller value if you have less memory.
+KUZU_BM_SIZE = 40 * 1024**3
+
+# Create Kùzu database:
+db = kuzu.Database(osp.abspath(osp.join('.', 'papers100M')), KUZU_BM_SIZE)
+
+# Get remote backend for PyG:
+feature_store, graph_store = db.get_torch_geometric_remote_backend(
+    mp.cpu_count())
+
+# Plug the graph store and feature store into the `NeighborLoader`.
+# Note that `filter_per_worker` is set to `False`. This is because the Kùzu
+# database is already using multi-threading to scan the features in parallel
+# and the database object is not fork-safe.
+loader = NeighborLoader(
+    data=(feature_store, graph_store),
+    num_neighbors={('paper', 'cites', 'paper'): [12, 12, 12]},
+    batch_size=LOADER_BATCH_SIZE,
+    input_nodes=('paper', input_nodes),
+    num_workers=4,
+    filter_per_worker=False,
+)
+
+
+class GraphSAGE(nn.Module):
+    def __init__(self, in_channels, hidden_channels, out_channels, num_layers,
+                 dropout=0.2):
+        super().__init__()
+        self.dropout = dropout
+
+        self.convs = nn.ModuleList()
+        self.norms = nn.ModuleList()
+
+        self.convs.append(SAGEConv(in_channels, hidden_channels))
+        self.norms.append(BatchNorm(hidden_channels))
+        for _ in range(1, num_layers):
+            self.convs.append(SAGEConv(hidden_channels, hidden_channels))
+            self.norms.append(BatchNorm(hidden_channels))
+
+        # The final MLP acts on the concatenation of the input features and
+        # all intermediate representations:
+        self.mlp = MLP(
+            in_channels=in_channels + num_layers * hidden_channels,
+            hidden_channels=2 * out_channels,
+            out_channels=out_channels,
+            num_layers=2,
+            norm='batch_norm',
+            act='leaky_relu',
+        )
+
+    def forward(self, x, edge_index):
+        x = F.dropout(x, p=self.dropout, training=self.training)
+        xs = [x]
+        for conv, norm in zip(self.convs, self.norms):
+            x = conv(x, edge_index)
+            x = norm(x)
+            x = x.relu()
+            x = F.dropout(x, p=self.dropout, training=self.training)
+            xs.append(x)
+        return self.mlp(torch.cat(xs, dim=-1))
+
+
+model = GraphSAGE(in_channels=128, hidden_channels=1024, out_channels=172,
+                  num_layers=3).to(device)
+optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
+
+for epoch in range(1, NUM_EPOCHS + 1):
+    total_loss = total_examples = 0
+    for batch in tqdm(loader):
+        batch = batch.to(device)
+        batch_size = batch['paper'].batch_size
+
+        optimizer.zero_grad()
+        out = model(
+            batch['paper'].x,
+            batch['paper', 'cites', 'paper'].edge_index,
+        )[:batch_size]
+        y = batch['paper'].y[:batch_size].long().view(-1)
+        loss = F.cross_entropy(out, y)
+
+        loss.backward()
+        optimizer.step()
+
+        total_loss += float(loss) * y.numel()
+        total_examples += y.numel()
+
+    print(f'Epoch: {epoch:02d}, Loss: {total_loss / total_examples:.4f}')
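+
+# A minimal inference sketch (an illustrative addition; it reuses the
+# training loader for convenience, while a real evaluation would load the
+# validation split instead):
+model.eval()
+with torch.no_grad():
+    batch = next(iter(loader)).to(device)
+    batch_size = batch['paper'].batch_size
+    logits = model(
+        batch['paper'].x,
+        batch['paper', 'cites', 'paper'].edge_index,
+    )[:batch_size]
+    pred = logits.argmax(dim=-1)
+    print('Predicted classes of the first examples:', pred[:10].tolist())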