-
Notifications
You must be signed in to change notification settings - Fork 3.7k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add Kùzu remote backend examples (#7298)
This PR adds examples of Kùzu's remote backend integration with PyG. --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Matthias Fey <matthias.fey@tu-dortmund.de>
- Loading branch information
1 parent
4d4c91a
commit 1bc5466
Showing
7 changed files
with
232 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
# Using Kùzu as a Remote Backend for PyG | ||
|
||
[Kùzu](https://kuzudb.com/) is an in-process property graph database management system built for query speed and scalability. | ||
It provides an integration with PyG via the [remote backend interface](https://pytorch-geometric.readthedocs.io/en/latest/advanced/remote.html) of PyG. | ||
The Python API of Kùzu outputs a [`torch_geometric.data.FeatureStore`](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.data.FeatureStore.html) and a [`torch_geometric.data.GraphStore`](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.data.GraphStore.html) that can be plugged directly into existing familiar PyG interfaces such as [`NeighborLoader`](https://pytorch-geometric.readthedocs.io/en/latest/_modules/torch_geometric/loader/neighbor_loader.html) and enables training GNNs directly on graphs stored in Kùzu. | ||
This is particularly useful if you would like to train graphs that don't fit on your CPU's memory. | ||
|
||
## Installation | ||
|
||
You can install Kùzu as follows: | ||
|
||
```bash | ||
pip install kuzu | ||
``` | ||
|
||
## Usage | ||
|
||
The API and design documentation of Kùzu can be found at [https://kuzudb.com/docs/](https://kuzudb.com/docs/). | ||
|
||
## Examples | ||
|
||
We provide the following examples to showcase the usage of Kùzu remote backend within PyG: | ||
|
||
### PubMed | ||
|
||
<a target="_blank" href="https://colab.research.google.com/drive/12fOSqPm1HQTz_m9caRW7E_92vaeD9xq6"> | ||
<img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/> | ||
</a> | ||
|
||
The PubMed example is hosted on [Google Colab](https://colab.research.google.com/drive/12fOSqPm1HQTz_m9caRW7E_92vaeD9xq6). | ||
In this example, we work on a small dataset for demonstrative purposes. | ||
The [PubMed](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.datasets.Planetoid.html) dataset consists of 19,717 papers as nodes and 88,648 citation relationships between them. | ||
|
||
### `papers_100M` | ||
|
||
This example shows how to use the remote backend feature of Kùzu to work with a large graph of papers and citations on a single machine. | ||
The data used in this example is `ogbn-papers100M` from the [Open Graph Benchmark](https://ogb.stanford.edu/). | ||
The dataset contains approximately 111 million nodes and 1.6 billion edges. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
# `papers_100M` Example | ||
|
||
This example shows how to use the remote backend feature of [Kùzu](https://kuzudb.com) to work with a large graph of papers and citations on a single machine. | ||
The data used in this example is `ogbn-papers100M` from the [Open Graph Benchmark](https://ogb.stanford.edu/). | ||
The dataset contains approximately 100 million nodes and 1.6 billion edges. | ||
|
||
## Prepare the data | ||
|
||
1. Download the dataset from [`http://snap.stanford.edu/ogb/data/nodeproppred/papers100M-bin.zip`](http://snap.stanford.edu/ogb/data/nodeproppred/papers100M-bin.zip) and put the `*.zip` file into this directory. | ||
2. Run `python prepare_data.py`. | ||
The script will automatically extract the data and convert it to the format that Kùzu can read. | ||
A Kùzu database instance is then created under `papers_100M` and the data is loaded into the it. | ||
|
||
## Train a Model | ||
|
||
Afterwards, run `python train.py` to train a three-layer [`GraphSAGE`](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.models.GraphSAGE.html) model on this dataset. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
from multiprocessing import cpu_count | ||
from os import path | ||
from zipfile import ZipFile | ||
|
||
import kuzu | ||
import numpy as np | ||
from tqdm import tqdm | ||
|
||
with ZipFile("papers100M-bin.zip", 'r') as papers100M_zip: | ||
print('Extracting papers100M-bin.zip...') | ||
papers100M_zip.extractall() | ||
|
||
with ZipFile("papers100M-bin/raw/data.npz", 'r') as data_zip: | ||
print('Extracting data.npz...') | ||
data_zip.extractall() | ||
|
||
with ZipFile("papers100M-bin/raw/node-label.npz", 'r') as node_label_zip: | ||
print('Extracting node-label.npz...') | ||
node_label_zip.extractall() | ||
|
||
print("Converting edge_index to CSV...") | ||
edge_index = np.load('edge_index.npy', mmap_mode='r') | ||
csvfile = open('edge_index.csv', 'w') | ||
csvfile.write('src,dst\n') | ||
for i in tqdm(range(edge_index.shape[1])): | ||
csvfile.write(str(edge_index[0, i]) + ',' + str(edge_index[1, i]) + '\n') | ||
csvfile.close() | ||
|
||
print("Generating IDs for nodes...") | ||
node_year = np.load('node_year.npy', mmap_mode='r') | ||
length = node_year.shape[0] | ||
ids = np.arange(length) | ||
np.save('ids.npy', ids) | ||
|
||
ids_path = path.abspath(path.join('.', 'ids.npy')) | ||
edge_index_path = path.abspath(path.join('.', 'edge_index.csv')) | ||
node_label_path = path.abspath(path.join('.', 'node_label.npy')) | ||
node_feature_path = path.abspath(path.join('.', 'node_feat.npy')) | ||
node_year_path = path.abspath(path.join('.', 'node_year.npy')) | ||
|
||
print("Creating Kùzu database...") | ||
db = kuzu.Database('papers100M') | ||
conn = kuzu.Connection(db, num_threads=cpu_count()) | ||
print("Creating Kùzu tables...") | ||
conn.execute( | ||
"CREATE NODE TABLE paper(id INT64, x FLOAT[128], year INT64, y FLOAT, " | ||
"PRIMARY KEY (id));") | ||
conn.execute("CREATE REL TABLE cites(FROM paper TO paper, MANY_MANY);") | ||
print("Copying nodes to Kùzu tables...") | ||
conn.execute('COPY paper FROM ("%s", "%s", "%s", "%s") BY COLUMN;' % | ||
(ids_path, node_feature_path, node_year_path, node_label_path)) | ||
print("Copying edges to Kùzu tables...") | ||
conn.execute('COPY cites FROM "%s";' % (edge_index_path)) | ||
print("All done!") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,119 @@ | ||
import multiprocessing as mp | ||
import os.path as osp | ||
|
||
import kuzu | ||
import pandas as pd | ||
import torch | ||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
from tqdm import tqdm | ||
|
||
from torch_geometric.loader import NeighborLoader | ||
from torch_geometric.nn import MLP, BatchNorm, SAGEConv | ||
|
||
NUM_EPOCHS = 1 | ||
LOADER_BATCH_SIZE = 1024 | ||
|
||
print('Batch size:', LOADER_BATCH_SIZE) | ||
print('Number of epochs:', NUM_EPOCHS) | ||
|
||
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') | ||
print('Using device:', device) | ||
|
||
# Load the train set: | ||
train_path = osp.join('.', 'papers100M-bin', 'split', 'time', 'train.csv.gz') | ||
train_df = pd.read_csv( | ||
osp.abspath(train_path), | ||
compression='gzip', | ||
header=None, | ||
) | ||
input_nodes = torch.tensor(train_df[0].values, dtype=torch.long) | ||
|
||
######################################################################## | ||
# The below code sets up the remote backend of Kùzu for PyG. | ||
# Please refer to: https://kuzudb.com/docs/client-apis/python-api/overview.html | ||
# for how to use the Python API of Kùzu. | ||
######################################################################## | ||
|
||
# The buffer pool size of Kùzu is set to 40GB. You can change it to a smaller | ||
# value if you have less memory. | ||
KUZU_BM_SIZE = 40 * 1024**3 | ||
|
||
# Create Kùzu database: | ||
db = kuzu.Database(osp.abspath(osp.join('.', 'papers100M')), KUZU_BM_SIZE) | ||
|
||
# Get remote backend for PyG: | ||
feature_store, graph_store = db.get_torch_geometric_remote_backend( | ||
mp.cpu_count()) | ||
|
||
# Plug the graph store and feature store into the `NeighborLoader`. | ||
# Note that `filter_per_worker` is set to `False`. This is because the Kùzu | ||
# database is already using multi-threading to scan the features in parallel | ||
# and the database object is not fork-safe. | ||
loader = NeighborLoader( | ||
data=(feature_store, graph_store), | ||
num_neighbors={('paper', 'cites', 'paper'): [12, 12, 12]}, | ||
batch_size=LOADER_BATCH_SIZE, | ||
input_nodes=('paper', input_nodes), | ||
num_workers=4, | ||
filter_per_worker=False, | ||
) | ||
|
||
|
||
class GraphSAGE(nn.Module): | ||
def __init__(self, in_channels, hidden_channels, out_channels, num_layers, | ||
dropout=0.2): | ||
super().__init__() | ||
|
||
self.convs = nn.ModuleList() | ||
self.norms = nn.ModuleList() | ||
|
||
self.convs.append(SAGEConv(in_channels, hidden_channels)) | ||
self.bns.append(BatchNorm(hidden_channels)) | ||
for i in range(1, num_layers): | ||
self.layers.append(SAGEConv(hidden_channels, hidden_channels)) | ||
self.bns.append(BatchNorm(hidden_channels)) | ||
|
||
self.mlp = MLP( | ||
in_channels=in_channels + num_layers * hidden_channels, | ||
hidden_channels=2 * out_channels, | ||
out_channels=out_channels, | ||
num_layers=2, | ||
norm='batch_norm', | ||
act='leaky_relu', | ||
) | ||
|
||
def forward(self, x, edge_index): | ||
x = F.dropout(x, p=self.dropout, training=self.training) | ||
xs = [x] | ||
for conv, norm in zip(self.convs, self.norms): | ||
x = conv(x, edge_index) | ||
x = norm(x) | ||
x = x.relu() | ||
x = F.dropout(x, p=self.dropout, training=self.training) | ||
xs.append(x) | ||
return self.mlp(torch.cat(xs, dim=-1)) | ||
|
||
|
||
model = GraphSAGE(in_channels=128, hidden_channels=1024, out_channels=172, | ||
num_layers=3).to(device) | ||
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4) | ||
|
||
for epoch in range(1, NUM_EPOCHS + 1): | ||
total_loss = total_examples = 0 | ||
for batch in tqdm(loader): | ||
batch = batch.to(device) | ||
batch_size = batch['paper'].batch_size | ||
|
||
optimizer.zero_grad() | ||
out = model(batch.x, batch.edge_index)[:batch_size] | ||
y = batch.y[:batch_size].long().view(-1) | ||
loss = F.cross_entropy_loss(out, y) | ||
|
||
loss.backward() | ||
optimizer.step() | ||
|
||
total_loss += float(loss) * y.numel() | ||
total_examples += y.numel() | ||
|
||
print(f'Epoch: {epoch:02d}, Loss: {total_loss / total_examples:.4f}') |