Skip to content

Commit

Permalink
[Feature] Jittor gcn example (#350)
Browse files Browse the repository at this point in the history
* Jittor gcn example

* Fix format
  • Loading branch information
cenyk1230 authored Jun 1, 2022
1 parent 92ad2a2 commit 0b4f1d3
Show file tree
Hide file tree
Showing 4 changed files with 170 additions and 0 deletions.
1 change: 1 addition & 0 deletions cogdl/layers/jittor/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .gcn_layer import GCNLayer
58 changes: 58 additions & 0 deletions cogdl/layers/jittor/gcn_layer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import jittor as jt
from jittor import nn, Module, init

from cogdl.operators.jt_spmm import spmm


class GCNLayer(Module):
def __init__(
self, in_features, out_features, dropout=0.0, activation=None, residual=False, norm=None, bias=True, **kwargs
):
super(GCNLayer, self).__init__()
self.in_features = in_features
self.out_features = out_features
self.linear = nn.Linear(in_features, out_features, bias=bias)
if dropout > 0:
self.dropout = nn.Dropout(dropout)
else:
self.dropout = None
if residual:
self.residual = nn.Linear(in_features, out_features)
else:
self.residual = None

if activation is not None and activation == "relu":
self.act = nn.ReLU()
else:
self.act = None

if norm is not None:
if norm == "batchnorm":
self.norm = nn.BatchNorm1d(out_features)
elif norm == "layernorm":
self.norm = nn.LayerNorm(out_features)
else:
raise NotImplementedError
else:
self.norm = None

self.reset_parameters()

def reset_parameters(self):
init.xavier_uniform_(self.linear.weight)

def execute(self, graph, x):
support = self.linear(x)
out = spmm(graph, support)

if self.norm is not None:
out = self.norm(out)
if self.act is not None:
out = self.act(out)

if self.residual is not None:
out = out + self.residual(x)
if self.dropout is not None:
out = self.dropout(out)
return out

48 changes: 48 additions & 0 deletions cogdl/operators/jt_spmm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import os
import jittor as jt

from jittor import Function
from jittor.compiler import compile_torch_extensions

jt.flags.use_cuda = 1
cached_op = {"csr_spmm": None}


def tensor2jit(x):
return jt.array(x.cpu().numpy())


def init_spmm_ops():
if cached_op["csr_spmm"] is None:
op_path = os.path.abspath(__file__)
spmm_path = os.path.join(os.path.dirname(op_path), "spmm/spmm.cpp")
spmm_cu_path = os.path.join(os.path.dirname(op_path), "spmm/spmm_kernel.cu")
compile_torch_extensions("spmm", [spmm_path, spmm_cu_path], 1, 1)
from spmm import csr_spmm

cached_op["csr_spmm"] = csr_spmm


def spmm(graph, x):
row_ptr, col_indices = graph.row_indptr, graph.col_indices
csr_data = graph.edge_weight
spmm = SPMM()
x = spmm(tensor2jit(row_ptr.int()), tensor2jit(col_indices.int()), x, tensor2jit(csr_data))
return x


class SPMM(Function):
def execute(self, rowptr, colind, feat, edge_weight_csr=None):
init_spmm_ops()
self.csr_spmm = cached_op["csr_spmm"]

out = self.csr_spmm(rowptr, colind, edge_weight_csr, feat)
self.backward_csc = (rowptr, colind, edge_weight_csr)
return out

def grad(self, grad_out):
rowptr, colind, edge_weight_csr = self.backward_csc
colptr, rowind, edge_weight_csc = rowptr, colind, edge_weight_csr
grad_feat = self.csr_spmm(colptr, rowind, edge_weight_csc, grad_out)

return None, None, grad_feat, None
63 changes: 63 additions & 0 deletions examples/jittor/gcn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
import jittor as jt
jt.flags.use_cuda = 1

from jittor import nn, Module, init
from jittor import optim
from jittor.contrib import slice_var_index

from tqdm import tqdm

from cogdl.layers.jittor import GCNLayer
from cogdl.datasets.planetoid_data import CoraDataset


def tensor2jit(x):
return jt.array(x.cpu().numpy())

class GCN(Module):
def __init__(self, in_feats, hidden_size, out_feats, dropout=0.5):
super(GCN, self).__init__()
self.in_feats = in_feats
self.conv1 = GCNLayer(in_feats, hidden_size, dropout=dropout, activation="relu")
self.conv2 = GCNLayer(hidden_size, out_feats)

def execute(self, graph):
graph.sym_norm()
x = tensor2jit(graph.x)
out = self.conv1(graph, x)
out = self.conv2(graph, out)
return out


def train(model, dataset):
graph = dataset[0]

optimizer = nn.AdamW(model.parameters(), lr=0.01)
loss_function = nn.CrossEntropyLoss()

train_mask = tensor2jit(graph.train_mask)
test_mask = tensor2jit(graph.test_mask)
val_mask = tensor2jit(graph.val_mask)
labels = tensor2jit(graph.y)

for epoch in range(100):
model.train()
output = model(graph)
loss = loss_function(output[train_mask], labels[train_mask])
optimizer.step(loss)

model.eval()
with jt.no_grad():
output = model(graph)
pred = output.argmax(1)[0]
train_acc = (pred[train_mask] == labels[train_mask]).float().mean()
val_acc = (pred[val_mask] == labels[val_mask]).float().mean()
test_acc = (pred[test_mask] == labels[test_mask]).float().mean()

print(f"Epoch:{epoch}, loss:{loss:.3f}, val_acc:{val_acc:.3f}, test_acc:{test_acc:.3f}")

if __name__ == "__main__":
dataset = CoraDataset()
model = GCN(in_feats=dataset.num_features, hidden_size=64, out_feats=dataset.num_classes, dropout=0.5)

train(model, dataset)

0 comments on commit 0b4f1d3

Please sign in to comment.