Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Kùzu remote backend examples #7298

Merged
merged 8 commits into from
May 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added Kùzu remote backend examples ([#7298](https://github.com/pyg-team/pytorch_geometric/pull/7298))
- Fixed tracing of `add_self_loops` for a dynamic number of nodes ([#7330](https://github.com/pyg-team/pytorch_geometric/pull/7330))
- Added an optional `add_pad_mask` argument to the `Pad` transform ([#7339](https://github.com/pyg-team/pytorch_geometric/pull/7339))
- Added `keep_inter_cluster_edges` option to `ClusterData` to support inter-subgraph edge connections when doing graph partitioning ([#7326](https://github.com/pyg-team/pytorch_geometric/pull/7326))
Expand Down
3 changes: 2 additions & 1 deletion docs/source/advanced/remote.rst
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,8 @@ An example usage of the interface is shown below:
assert torch.equal(row, edge_index[0])
assert torch.equal(col, edge_index[1])
Common implementations of the :class:`~torch_geometric.data.GraphStore` are graph databases, *e.g.*, :obj:`Neo4j`, :obj:`TigerGraph`, :obj:`ArangoDB` are all viable performant options.
Common implementations of the :class:`~torch_geometric.data.GraphStore` are graph databases, *e.g.*, :obj:`Neo4j`, :obj:`TigerGraph`, :obj:`ArangoDB`, :obj:`Kùzu` are all viable performant options.
We provide an example of using :pyg:`PyG` in combination with the :obj:`Kùzu` database `here <https://github.com/pyg-team/pytorch_geometric/tree/master/examples/kuzu>__`.

A graph sampler is tightly coupled to the given :class:`~torch_geometric.data.GraphStore`, and operates on the :class:`~torch_geometric.data.GraphStore` to produce sampled subgraphs from input nodes.
Different sampling algorithms are implemented behind the :class:`torch_geometric.sampler.BaseSampler` interface.
Expand Down
2 changes: 2 additions & 0 deletions docs/source/external/resources.rst
Original file line number Diff line number Diff line change
Expand Up @@ -38,3 +38,5 @@ External Resources
* Amitoz Azad: **Primal-Dual Algorithm for Total Variation Processing on Graphs** [`Jupyter <https://nbviewer.jupyter.org/github/aGIToz/Graph_Signal_Processing/tree/main>`__]

* Manan Goel: **Recommending Amazon Products using Graph Neural Networks in** :pyg:`null` **PyTorch Geometric** [:wandb:`null` `W&B Report <https://wandb.ai/manan-goel/gnn-recommender/reports/Recommending-Amazon-Products-using-Graph-Neural-Networks-in-PyTorch-Geometric--VmlldzozMTA3MzYw>`__]

* Kùzu: **Remote Backend for** :pyg:`null` **PyTorch Geometric** [:colab:`null` `Colab <https://colab.research.google.com/drive/12fOSqPm1HQTz_m9caRW7E_92vaeD9xq6>`__]
38 changes: 38 additions & 0 deletions examples/kuzu/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# Using Kùzu as a Remote Backend for PyG

[Kùzu](https://kuzudb.com/) is an in-process property graph database management system built for query speed and scalability.
It provides an integration with PyG via the [remote backend interface](https://pytorch-geometric.readthedocs.io/en/latest/advanced/remote.html) of PyG.
The Python API of Kùzu outputs a [`torch_geometric.data.FeatureStore`](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.data.FeatureStore.html) and a [`torch_geometric.data.GraphStore`](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.data.GraphStore.html) that can be plugged directly into existing familiar PyG interfaces such as [`NeighborLoader`](https://pytorch-geometric.readthedocs.io/en/latest/_modules/torch_geometric/loader/neighbor_loader.html) and enables training GNNs directly on graphs stored in Kùzu.
This is particularly useful if you would like to train graphs that don't fit on your CPU's memory.

## Installation

You can install Kùzu as follows:

```bash
pip install kuzu
```

## Usage

The API and design documentation of Kùzu can be found at [https://kuzudb.com/docs/](https://kuzudb.com/docs/).

## Examples

We provide the following examples to showcase the usage of Kùzu remote backend within PyG:

### PubMed

<a target="_blank" href="https://colab.research.google.com/drive/12fOSqPm1HQTz_m9caRW7E_92vaeD9xq6">
<img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

The PubMed example is hosted on [Google Colab](https://colab.research.google.com/drive/12fOSqPm1HQTz_m9caRW7E_92vaeD9xq6).
In this example, we work on a small dataset for demonstrative purposes.
The [PubMed](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.datasets.Planetoid.html) dataset consists of 19,717 papers as nodes and 88,648 citation relationships between them.

### `papers_100M`

This example shows how to use the remote backend feature of Kùzu to work with a large graph of papers and citations on a single machine.
The data used in this example is `ogbn-papers100M` from the [Open Graph Benchmark](https://ogb.stanford.edu/).
The dataset contains approximately 111 million nodes and 1.6 billion edges.
16 changes: 16 additions & 0 deletions examples/kuzu/papers_100M/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# `papers_100M` Example

This example shows how to use the remote backend feature of [Kùzu](https://kuzudb.com) to work with a large graph of papers and citations on a single machine.
The data used in this example is `ogbn-papers100M` from the [Open Graph Benchmark](https://ogb.stanford.edu/).
The dataset contains approximately 100 million nodes and 1.6 billion edges.

## Prepare the data

1. Download the dataset from [`http://snap.stanford.edu/ogb/data/nodeproppred/papers100M-bin.zip`](http://snap.stanford.edu/ogb/data/nodeproppred/papers100M-bin.zip) and put the `*.zip` file into this directory.
2. Run `python prepare_data.py`.
The script will automatically extract the data and convert it to the format that Kùzu can read.
A Kùzu database instance is then created under `papers_100M` and the data is loaded into the it.

## Train a Model

Afterwards, run `python train.py` to train a three-layer [`GraphSAGE`](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.models.GraphSAGE.html) model on this dataset.
54 changes: 54 additions & 0 deletions examples/kuzu/papers_100M/prepare_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
from multiprocessing import cpu_count
from os import path
from zipfile import ZipFile

import kuzu
import numpy as np
from tqdm import tqdm

with ZipFile("papers100M-bin.zip", 'r') as papers100M_zip:
print('Extracting papers100M-bin.zip...')
papers100M_zip.extractall()

with ZipFile("papers100M-bin/raw/data.npz", 'r') as data_zip:
print('Extracting data.npz...')
data_zip.extractall()

with ZipFile("papers100M-bin/raw/node-label.npz", 'r') as node_label_zip:
print('Extracting node-label.npz...')
node_label_zip.extractall()

print("Converting edge_index to CSV...")
edge_index = np.load('edge_index.npy', mmap_mode='r')
csvfile = open('edge_index.csv', 'w')
csvfile.write('src,dst\n')
for i in tqdm(range(edge_index.shape[1])):
csvfile.write(str(edge_index[0, i]) + ',' + str(edge_index[1, i]) + '\n')
csvfile.close()

print("Generating IDs for nodes...")
node_year = np.load('node_year.npy', mmap_mode='r')
length = node_year.shape[0]
ids = np.arange(length)
np.save('ids.npy', ids)

ids_path = path.abspath(path.join('.', 'ids.npy'))
edge_index_path = path.abspath(path.join('.', 'edge_index.csv'))
node_label_path = path.abspath(path.join('.', 'node_label.npy'))
node_feature_path = path.abspath(path.join('.', 'node_feat.npy'))
node_year_path = path.abspath(path.join('.', 'node_year.npy'))

print("Creating Kùzu database...")
db = kuzu.Database('papers100M')
conn = kuzu.Connection(db, num_threads=cpu_count())
print("Creating Kùzu tables...")
conn.execute(
"CREATE NODE TABLE paper(id INT64, x FLOAT[128], year INT64, y FLOAT, "
"PRIMARY KEY (id));")
conn.execute("CREATE REL TABLE cites(FROM paper TO paper, MANY_MANY);")
print("Copying nodes to Kùzu tables...")
conn.execute('COPY paper FROM ("%s", "%s", "%s", "%s") BY COLUMN;' %
(ids_path, node_feature_path, node_year_path, node_label_path))
print("Copying edges to Kùzu tables...")
conn.execute('COPY cites FROM "%s";' % (edge_index_path))
print("All done!")
119 changes: 119 additions & 0 deletions examples/kuzu/papers_100M/train.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
import multiprocessing as mp
import os.path as osp

import kuzu
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm

from torch_geometric.loader import NeighborLoader
from torch_geometric.nn import MLP, BatchNorm, SAGEConv

NUM_EPOCHS = 1
LOADER_BATCH_SIZE = 1024

print('Batch size:', LOADER_BATCH_SIZE)
print('Number of epochs:', NUM_EPOCHS)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)

# Load the train set:
train_path = osp.join('.', 'papers100M-bin', 'split', 'time', 'train.csv.gz')
train_df = pd.read_csv(
osp.abspath(train_path),
compression='gzip',
header=None,
)
input_nodes = torch.tensor(train_df[0].values, dtype=torch.long)

########################################################################
# The below code sets up the remote backend of Kùzu for PyG.
# Please refer to: https://kuzudb.com/docs/client-apis/python-api/overview.html
# for how to use the Python API of Kùzu.
########################################################################

# The buffer pool size of Kùzu is set to 40GB. You can change it to a smaller
# value if you have less memory.
KUZU_BM_SIZE = 40 * 1024**3

# Create Kùzu database:
db = kuzu.Database(osp.abspath(osp.join('.', 'papers100M')), KUZU_BM_SIZE)

# Get remote backend for PyG:
feature_store, graph_store = db.get_torch_geometric_remote_backend(
mp.cpu_count())

# Plug the graph store and feature store into the `NeighborLoader`.
# Note that `filter_per_worker` is set to `False`. This is because the Kùzu
# database is already using multi-threading to scan the features in parallel
# and the database object is not fork-safe.
loader = NeighborLoader(
data=(feature_store, graph_store),
num_neighbors={('paper', 'cites', 'paper'): [12, 12, 12]},
batch_size=LOADER_BATCH_SIZE,
input_nodes=('paper', input_nodes),
num_workers=4,
filter_per_worker=False,
)


class GraphSAGE(nn.Module):
def __init__(self, in_channels, hidden_channels, out_channels, num_layers,
dropout=0.2):
super().__init__()

self.convs = nn.ModuleList()
self.norms = nn.ModuleList()

self.convs.append(SAGEConv(in_channels, hidden_channels))
self.bns.append(BatchNorm(hidden_channels))
for i in range(1, num_layers):
self.layers.append(SAGEConv(hidden_channels, hidden_channels))
self.bns.append(BatchNorm(hidden_channels))

self.mlp = MLP(
in_channels=in_channels + num_layers * hidden_channels,
hidden_channels=2 * out_channels,
out_channels=out_channels,
num_layers=2,
norm='batch_norm',
act='leaky_relu',
)

def forward(self, x, edge_index):
x = F.dropout(x, p=self.dropout, training=self.training)
xs = [x]
for conv, norm in zip(self.convs, self.norms):
x = conv(x, edge_index)
x = norm(x)
x = x.relu()
x = F.dropout(x, p=self.dropout, training=self.training)
xs.append(x)
return self.mlp(torch.cat(xs, dim=-1))


model = GraphSAGE(in_channels=128, hidden_channels=1024, out_channels=172,
num_layers=3).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)

for epoch in range(1, NUM_EPOCHS + 1):
total_loss = total_examples = 0
for batch in tqdm(loader):
batch = batch.to(device)
batch_size = batch['paper'].batch_size

optimizer.zero_grad()
out = model(batch.x, batch.edge_index)[:batch_size]
y = batch.y[:batch_size].long().view(-1)
loss = F.cross_entropy_loss(out, y)

loss.backward()
optimizer.step()

total_loss += float(loss) * y.numel()
total_examples += y.numel()

print(f'Epoch: {epoch:02d}, Loss: {total_loss / total_examples:.4f}')