Skip to content

Commit

Permalink
[Operator] add scatter_max (#308)
Browse files Browse the repository at this point in the history
* new scatter_max

* [Bugfix] Update loading datasets (#307)

* Fix scatter max

* Add max-aggr sage

* Update submodule urls

* Update submodule urls

* fix bugs for scatter_max

* use relative path

* fix bugs for multi gpus

* remove redundancy

Co-authored-by: Yukuo Cen <cenyk1230@qq.com>
  • Loading branch information
fishmingyu and cenyk1230 authored Nov 29, 2021
1 parent f6f33c6 commit e576766
Show file tree
Hide file tree
Showing 8 changed files with 501 additions and 349 deletions.
6 changes: 3 additions & 3 deletions .gitmodules
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
[submodule "third_party/dgNN"]
ignore = dirty
path = third_party/dgNN
url = https://github.com/dgSPARSE/dgNN
url = ../../dgSPARSE/dgNN
branch = main
[submodule "third_party/actnn"]
ignore = dirty
path = third_party/actnn
url = https://github.com/ucbrise/actnn
url = ../../ucbrise/actnn
branch = main
[submodule "third_party/fastmoe"]
ignore = dirty
path = third_party/fastmoe
url = https://github.com/laekov/fastmoe
url = ../../laekov/fastmoe
branch = master
13 changes: 13 additions & 0 deletions cogdl/layers/sage_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,17 @@ def __call__(self, graph, x):
return x


class MaxAggregator(object):
    """Aggregator that forwards a graph's index arrays and features to ``scatter_max``."""

    def __init__(self):
        # The operator is imported when the aggregator is constructed, not when
        # this module is imported.
        from cogdl.operators.scatter_max import scatter_max as scatter_max_op

        self.scatter_max = scatter_max_op

    def __call__(self, graph, x):
        """Apply ``scatter_max`` to ``x`` using ``graph.row_indptr`` and ``graph.col_indices``.

        Both index arrays are converted with ``.int()`` before being handed to
        the operator; the operator's result is returned unchanged.
        """
        indptr = graph.row_indptr.int()
        indices = graph.col_indices.int()
        return self.scatter_max(indptr, indices, x)


class SAGELayer(nn.Module):
def __init__(
self, in_feats, out_feats, normalize=False, aggr="mean", dropout=0.0, norm=None, activation=None, residual=False
Expand All @@ -35,6 +46,8 @@ def __init__(
self.aggr = MeanAggregator()
elif aggr == "sum":
self.aggr = SumAggregator()
elif aggr == "max":
self.aggr = MaxAggregator()
else:
raise NotImplementedError

Expand Down
37 changes: 37 additions & 0 deletions cogdl/operators/scatter_max.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import os
import numpy as np
import torch
from torch.utils.cpp_extension import load

# Directory of this module; the extension sources are resolved relative to it.
path = os.path.join(os.path.dirname(__file__))

# JIT-build and load the scatter-max extension. A failure here is remembered
# instead of being discarded, so that the public entry point below always
# exists and can report why the extension is unavailable.
try:
    spmm_max = load(
        name="scatter_max",
        sources=[os.path.join(path, "scatter_max/scatter_max.cc"), os.path.join(path, "scatter_max/scatter_max.cu")],
        verbose=True,
    )
    _load_error = None
except Exception as exc:
    spmm_max = None
    # `exc` is unbound when the except block ends, so keep a separate reference.
    _load_error = exc


def scatter_max(rowptr, colind, feat):
    """Run ``ScatterMaxFunction`` on ``(rowptr, colind, feat)``.

    Raises:
        RuntimeError: if the extension could not be built or loaded; the
            original load error is attached as the cause.
    """
    if spmm_max is None:
        raise RuntimeError(
            "scatter_max extension is unavailable: building/loading the CUDA extension failed"
        ) from _load_error
    return ScatterMaxFunction.apply(rowptr, colind, feat)


class ScatterMaxFunction(torch.autograd.Function):
    """Autograd wrapper around the extension's ``scatter_max_fp`` / ``scatter_max_bp``."""

    @staticmethod
    def forward(ctx, rowptr, colind, feat):
        """Return the extension's forward output and save ``max_id`` for backward."""
        # The extension requires contiguous tensors (its C++ entry point checks
        # `is_contiguous()`), so normalise the layout here the same way
        # `backward` already does for `grad`. `.contiguous()` returns the tensor
        # itself when it is already contiguous.
        out, max_id = spmm_max.scatter_max_fp(rowptr.contiguous(), colind.contiguous(), feat.contiguous())
        ctx.save_for_backward(max_id)
        return out

    @staticmethod
    def backward(ctx, grad):
        """Route ``grad`` through the saved ``max_id``.

        Returns ``(None, None, out)``: no gradient for ``rowptr`` and
        ``colind``, and the extension's backward result for ``feat``.
        """
        grad = grad.contiguous()
        max_id = ctx.saved_tensors[0]
        out = spmm_max.scatter_max_bp(grad, max_id)
        return None, None, out
38 changes: 38 additions & 0 deletions cogdl/operators/scatter_max/scatter_max.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
#include <pybind11/pybind11.h>
#include <torch/extension.h>
#include <vector>

// Validates that T is a contiguous CUDA tensor with scalar type `type`.
//
// TORCH_CHECK is used instead of assert(): assert() is compiled out when
// NDEBUG is defined (so invalid tensors would reach the kernels unchecked)
// and, when it is active, it aborts the whole process instead of raising an
// error the caller can handle.
void assertTensor(torch::Tensor &T, c10::ScalarType type) {
  TORCH_CHECK(T.is_contiguous(), "scatter_max: tensor must be contiguous");
  TORCH_CHECK(T.device().type() == torch::kCUDA,
              "scatter_max: tensor must be on a CUDA device");
  TORCH_CHECK(T.scalar_type() == type, "scatter_max: expected dtype ", type,
              " but got ", T.scalar_type());
}

// Forward declaration of the forward launcher; the definition lives in
// scatter_max.cu. Called by scatter_max() below after its arguments have been
// validated.
std::vector<torch::Tensor> scatter_max_fp_cuda(torch::Tensor rowptr,
torch::Tensor colind,
torch::Tensor node_feature);

// Forward declaration of the backward launcher. It takes exactly two
// arguments, matching both its definition in scatter_max.cu and the call in
// scatter_max_bp(); a three-parameter declaration would make that call
// ill-formed.
torch::Tensor scatter_max_bp_cuda(torch::Tensor node_feature,
                                  torch::Tensor max_mask);

// Forward entry point exposed to Python as `scatter_max_fp`.
//
// Validates the three inputs (int32 rowptr/colind, float32 node_feature, all
// contiguous CUDA tensors) and forwards them to scatter_max_fp_cuda, whose
// result is returned unchanged.
std::vector<torch::Tensor> scatter_max(torch::Tensor rowptr,
                                       torch::Tensor colind,
                                       torch::Tensor node_feature) {
  assertTensor(rowptr, torch::kInt32);
  assertTensor(colind, torch::kInt32);
  assertTensor(node_feature, torch::kFloat32);
  // The launcher derives the row count from rowptr.size(0) - 1 and the
  // feature width from node_feature.size(1); reject shapes for which those
  // quantities are meaningless instead of silently mis-indexing memory.
  TORCH_CHECK(rowptr.dim() == 1 && rowptr.size(0) >= 1,
              "scatter_max: rowptr must be 1-D with at least one entry");
  TORCH_CHECK(node_feature.dim() == 2,
              "scatter_max: node_feature must be 2-D, got ",
              node_feature.dim(), " dimension(s)");
  return scatter_max_fp_cuda(rowptr, colind, node_feature);
}

// Backward entry point exposed to Python as `scatter_max_bp`.
//
// `node_feature` is the incoming float32 gradient and `max_mask` the int32
// index tensor saved by the forward pass. Both are validated and then handed
// to scatter_max_bp_cuda, whose result is returned unchanged.
torch::Tensor scatter_max_bp(torch::Tensor node_feature,
                             torch::Tensor max_mask) {
  assertTensor(node_feature, torch::kFloat32);
  assertTensor(max_mask, torch::kInt32);
  // The backward kernel reads one mask entry per gradient element, so the two
  // tensors must agree in shape; otherwise it would read out of bounds.
  TORCH_CHECK(node_feature.dim() == 2,
              "scatter_max_bp: gradient must be 2-D, got ",
              node_feature.dim(), " dimension(s)");
  TORCH_CHECK(node_feature.sizes() == max_mask.sizes(),
              "scatter_max_bp: gradient and max_mask must have the same shape");
  return scatter_max_bp_cuda(node_feature, max_mask);
}

// Python bindings. The module is named `scatter_max`, the same name the Python
// side passes as `name=` when loading the extension.
// NOTE(review): presumably these two names must stay in sync for the import to
// succeed; confirm before renaming either.
PYBIND11_MODULE(scatter_max, m) {
m.doc() = "scatter max kernel";
// scatter_max()    -> scatter_max_fp (forward: returns output and max indices)
m.def("scatter_max_fp", &scatter_max, "scatter max forward");
// scatter_max_bp() -> scatter_max_bp (backward: routes gradient via max indices)
m.def("scatter_max_bp", &scatter_max_bp, "scatter max backward");
}
75 changes: 75 additions & 0 deletions cogdl/operators/scatter_max/scatter_max.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
#include <cuda.h>
#include <torch/types.h>
#include <vector>

// Per-row, per-feature maximum over the neighbours listed in a row of
// (A_indptr, A_indices).
//
// Launch layout (taken from the launch configuration, not from arguments):
//   gridDim.x  = number of rows (one block per row),
//   blockDim.x = feature width k (one thread per feature column),
//   blockDim.y = 1.
// No shared memory is used.
//
// For row `rid` and feature `t`:
//   C[rid * k + t]        = max over ptr in [A_indptr[rid], A_indptr[rid+1])
//                           of B[A_indices[ptr] * k + t]
//   max_mask[rid * k + t] = the A_indices value that attained that maximum
//                           (first one wins on ties).
// Rows with no entries write 0 to C and the sentinel -1 to max_mask, so the
// backward kernel can recognise that there is nothing to route.
//
// The running maximum is seeded from the first neighbour rather than from a
// constant: FLT_MIN is the smallest *positive* float, so seeding with it gave
// a wrong result (and an uninitialised max_mask entry) whenever every
// neighbour value was <= FLT_MIN, e.g. all-negative features.
__global__ void scatter_max_forward(const int *A_indptr, const int *A_indices,
                                    const float *B, float *C, int *max_mask) {
  int rid = blockDim.y * blockIdx.x + threadIdx.y;
  int m = gridDim.x;
  int k = blockDim.x;
  if (rid < m) {
    int lb = A_indptr[rid];
    int hb = A_indptr[(rid + 1)];
    float acc = 0.0f;
    int max_id = -1; // sentinel: row has no neighbours
    if (lb < hb) {
      max_id = A_indices[lb];
      // 64-bit offsets: cid * k can exceed INT_MAX for large inputs.
      acc = B[(long long)max_id * k + threadIdx.x];
      for (int ptr = lb + 1; ptr < hb; ptr++) {
        int cid = A_indices[ptr];
        float val = B[(long long)cid * k + threadIdx.x];
        if (acc < val) {
          acc = val;
          max_id = cid;
        }
      }
    }
    long long out_idx = (long long)rid * k + threadIdx.x;
    C[out_idx] = acc;
    max_mask[out_idx] = max_id;
  }
}

// Routes each gradient element to the input position that produced the
// forward maximum.
//
// Launch layout: gridDim.x = number of gradient rows, blockDim.x = feature
// width k, blockDim.y = 1. `out` is accumulated into with atomicAdd (several
// rows may share the same max index), so the caller must provide it already
// initialised.
//
// Negative entries in max_mask are skipped: they mark rows for which no
// neighbour was selected, and using them as an index would write out of
// bounds.
__global__ void scatter_max_backward(const float *grad, float *out,
                                     int *max_mask) {
  int rid = blockDim.y * blockIdx.x + threadIdx.y;
  int m = gridDim.x;
  int k = blockDim.x;
  if (rid < m) {
    long long offset = (long long)rid * k + threadIdx.x;
    int max_id = max_mask[offset]; // max mapping
    if (max_id >= 0) {
      float grad_tmp = grad[offset];
      atomicAdd(&out[(long long)max_id * k + threadIdx.x], grad_tmp);
    }
  }
}

// Host launcher for scatter_max_forward.
//
// Allocates an {m, k} float32 output and an {m, k} int32 index tensor on
// node_feature's device, where m = rowptr.size(0) - 1 and
// k = node_feature.size(1), launches one block per row with k threads, and
// returns {out, max_mask}.
//
// NOTE(review): the kernel is launched without an explicit stream or device
// guard, i.e. on the default stream of whatever device is current; confirm
// this matches how callers set the device/stream.
std::vector<torch::Tensor> scatter_max_fp_cuda(torch::Tensor rowptr,
                                               torch::Tensor colind,
                                               torch::Tensor node_feature) {
  const long m = rowptr.size(0) - 1;
  const long k = node_feature.size(1);
  auto devid = node_feature.device().index();
  auto optionsI =
      torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA, devid);
  auto optionsF =
      torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA, devid);
  auto max_mask = torch::empty({m, k}, optionsI);
  auto out = torch::empty({m, k}, optionsF);

  // A grid or block dimension of zero is an invalid launch configuration;
  // there is nothing to compute, so return the (empty) outputs directly.
  if (m == 0 || k == 0) {
    return {out, max_mask};
  }
  // The kernel maps one thread to each feature column, so k is bounded by the
  // per-block thread limit. Without this check an oversized launch fails
  // silently and the outputs are returned uninitialised.
  TORCH_CHECK(k <= 1024, "scatter_max: feature width ", k,
              " exceeds the 1024 threads-per-block limit of this kernel");

  scatter_max_forward<<<static_cast<unsigned int>(m),
                        static_cast<unsigned int>(k)>>>(
      rowptr.data_ptr<int>(), colind.data_ptr<int>(),
      node_feature.data_ptr<float>(), out.data_ptr<float>(),
      max_mask.data_ptr<int>());
  // Kernel launches do not report errors themselves; surface them here.
  cudaError_t err = cudaGetLastError();
  TORCH_CHECK(err == cudaSuccess, "scatter_max_forward launch failed: ",
              cudaGetErrorString(err));
  return {out, max_mask};
}

// Host launcher for scatter_max_backward.
//
// `node_feature` is the incoming gradient of shape {m, k}; `max_mask` holds,
// per gradient element, the row index to which that gradient is routed. The
// result has the same {m, k} shape as the gradient.
// NOTE(review): this assumes the forward input had as many rows as the
// forward output (every index in max_mask < m) — confirm for non-square
// graphs.
torch::Tensor scatter_max_bp_cuda(torch::Tensor node_feature,
                                  torch::Tensor max_mask) {
  const long m = node_feature.size(0);
  const long k = node_feature.size(1);
  auto devid = node_feature.device().index();
  auto options =
      torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA, devid);
  // Must start from zero: the kernel accumulates with atomicAdd, and rows that
  // are never selected as a maximum receive no write at all.
  auto out = torch::zeros({m, k}, options);

  // A grid or block dimension of zero is an invalid launch configuration.
  if (m == 0 || k == 0) {
    return out;
  }
  // One thread per feature column, so k is bounded by the per-block limit.
  TORCH_CHECK(k <= 1024, "scatter_max_bp: feature width ", k,
              " exceeds the 1024 threads-per-block limit of this kernel");

  scatter_max_backward<<<static_cast<unsigned int>(m),
                         static_cast<unsigned int>(k)>>>(
      node_feature.data_ptr<float>(), out.data_ptr<float>(),
      max_mask.data_ptr<int>());
  // Kernel launches do not report errors themselves; surface them here.
  cudaError_t err = cudaGetLastError();
  TORCH_CHECK(err == cudaSuccess, "scatter_max_backward launch failed: ",
              cudaGetErrorString(err));
  return out;
}
Loading

0 comments on commit e576766

Please sign in to comment.