Commit message:

* new scatter_max
* [Bugfix] Update loading datasets (#307)
* Fix scatter max
* Add max-aggr sage
* Update submodule urls
* Update submodule urls
* fix bugs for scatter_max
* use relative path
* fix bugs for multi gpus
* remove redundancy

Co-authored-by: Yukuo Cen <cenyk1230@qq.com>
1 parent f6f33c6, commit e576766. Showing 8 changed files with 501 additions and 349 deletions.
.gitmodules
@@ -1,15 +1,15 @@
 [submodule "third_party/dgNN"]
 	ignore = dirty
 	path = third_party/dgNN
-	url = https://github.com/dgSPARSE/dgNN
+	url = ../../dgSPARSE/dgNN
 	branch = main
 [submodule "third_party/actnn"]
 	ignore = dirty
 	path = third_party/actnn
-	url = https://github.com/ucbrise/actnn
+	url = ../../ucbrise/actnn
 	branch = main
 [submodule "third_party/fastmoe"]
 	ignore = dirty
 	path = third_party/fastmoe
-	url = https://github.com/laekov/fastmoe
+	url = ../../laekov/fastmoe
 	branch = master
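Switching the submodule url entries from absolute addresses to relative paths makes git resolve them against the superproject's own remote, so forks and mirrors fetch the submodules from their matching namespace. For an existing clone, running `git submodule sync --recursive` followed by `git submodule update --init --recursive` picks up the new URLs (standard git commands, not part of this commit).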
New file (Python wrapper that loads and autograd-wraps the scatter_max extension):
@@ -0,0 +1,37 @@
import os

import numpy as np
import torch
from torch.utils.cpp_extension import load

path = os.path.join(os.path.dirname(__file__))

# SPMM

# JIT-compile the CUDA extension on import; fall back to None if the build fails.
try:
    spmm_max = load(
        name="scatter_max",
        sources=[os.path.join(path, "scatter_max/scatter_max.cc"), os.path.join(path, "scatter_max/scatter_max.cu")],
        verbose=True,
    )

    def scatter_max(rowptr, colind, feat):
        return ScatterMaxFunction.apply(rowptr, colind, feat)


except Exception:
    spmm_max = None


class ScatterMaxFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, rowptr, colind, feat):
        # out: per-row max over CSR neighbors; max_id: which neighbor produced each max
        out, max_id = spmm_max.scatter_max_fp(rowptr, colind, feat)
        ctx.save_for_backward(max_id)
        return out

    @staticmethod
    def backward(ctx, grad):
        # Route the incoming gradient back to the argmax neighbor of each (row, channel).
        grad = grad.contiguous()
        max_id = ctx.saved_tensors[0]
        out = spmm_max.scatter_max_bp(grad, max_id)
        return None, None, out  # no gradients for rowptr/colind
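A rough usage sketch of the wrapper (not part of the commit): it assumes the JIT build above succeeded and a CUDA device is available, and the tiny graph below is made up. The kernel expects int32 CSR tensors and float32 features on the GPU (see the assertions in the C++ binding below); since the backward pass returns one gradient row per destination row, the sketch uses a graph where the number of rows equals the number of nodes.

# Illustrative only: a 4-node graph in CSR form.
import torch

rowptr = torch.tensor([0, 2, 3, 5, 6], dtype=torch.int32, device="cuda")
colind = torch.tensor([1, 2, 0, 1, 3, 2], dtype=torch.int32, device="cuda")
feat = torch.rand(4, 8, device="cuda", requires_grad=True)  # [num_nodes, dim]

out = scatter_max(rowptr, colind, feat)  # [4, 8]: per-row max over neighbors
out.sum().backward()                     # gradient flows to the argmax neighbors
print(out.shape, feat.grad.shape)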
New file (C++ pybind11 bindings for the scatter_max CUDA kernels):
@@ -0,0 +1,38 @@
#include <pybind11/pybind11.h>
#include <torch/extension.h>
#include <vector>

// Check that a tensor is contiguous, on the GPU, and of the expected dtype.
void assertTensor(torch::Tensor &T, c10::ScalarType type) {
  assert(T.is_contiguous());
  assert(T.device().type() == torch::kCUDA);
  assert(T.dtype() == type);
}

// CUDA implementations, defined in scatter_max.cu.
std::vector<torch::Tensor> scatter_max_fp_cuda(torch::Tensor rowptr,
                                               torch::Tensor colind,
                                               torch::Tensor node_feature);

torch::Tensor scatter_max_bp_cuda(torch::Tensor node_feature,
                                  torch::Tensor max_mask);

// Forward: per-row max over CSR neighbors; returns {out, max_mask}.
std::vector<torch::Tensor> scatter_max(torch::Tensor rowptr,
                                       torch::Tensor colind,
                                       torch::Tensor node_feature) {
  assertTensor(rowptr, torch::kInt32);
  assertTensor(colind, torch::kInt32);
  assertTensor(node_feature, torch::kFloat32);
  return scatter_max_fp_cuda(rowptr, colind, node_feature);
}

// Backward: scatter the gradient to the argmax positions recorded in max_mask.
torch::Tensor scatter_max_bp(torch::Tensor node_feature,
                             torch::Tensor max_mask) {
  assertTensor(node_feature, torch::kFloat32);
  assertTensor(max_mask, torch::kInt32);
  return scatter_max_bp_cuda(node_feature, max_mask);
}

PYBIND11_MODULE(scatter_max, m) {
  m.doc() = "scatter max kernel";
  m.def("scatter_max_fp", &scatter_max, "scatter max forward");
  m.def("scatter_max_bp", &scatter_max_bp, "scatter max backward");
}
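Note that the module name in PYBIND11_MODULE matches the name passed to load() in the Python wrapper above. As a rough alternative to the JIT build, the same two sources could be compiled ahead of time with setuptools; the sketch below is illustrative only, and its paths and package layout are assumptions rather than part of this commit.

# setup.py sketch: ahead-of-time build of the extension; paths are assumed.
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

setup(
    name="scatter_max",
    ext_modules=[
        CUDAExtension(
            name="scatter_max",
            sources=["scatter_max/scatter_max.cc", "scatter_max/scatter_max.cu"],
        )
    ],
    cmdclass={"build_ext": BuildExtension},
)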
New file (CUDA kernels implementing the scatter_max forward and backward passes):
@@ -0,0 +1,75 @@
#include <cfloat>
#include <cuda.h>
#include <torch/types.h>
#include <vector>

// Forward kernel: one block per destination row, one thread per feature
// channel. Each thread takes the max over the row's CSR neighbors and
// records which neighbor produced it.
__global__ void scatter_max_forward(const int *A_indptr, const int *A_indices,
                                    const float *B, float *C, int *max_mask) {
  int rid = blockDim.y * blockIdx.x + threadIdx.y;
  int m = gridDim.x;
  int k = blockDim.x;
  if (rid < m) {
    int lb = A_indptr[rid];
    int hb = A_indptr[(rid + 1)];
    int stride = hb - lb;
    int offset;
    int max_id = 0;
    // Start below any representable value; rows with no neighbors output 0.
    float acc = (stride > 0) ? -FLT_MAX : 0;
    for (int ptr = lb; ptr < hb; ptr++) {
      int cid = A_indices[ptr];
      offset = cid * k + threadIdx.x;
      if (acc < B[offset]) {
        acc = B[offset];
        max_id = cid;
      }
    }
    C[(rid * k + threadIdx.x)] = acc;
    max_mask[(rid * k + threadIdx.x)] = max_id;
  }
}

// Backward kernel: route each (row, channel) gradient to the neighbor that
// won the max in the forward pass.
__global__ void scatter_max_backward(const float *grad, float *out,
                                     int *max_mask) {
  int rid = blockDim.y * blockIdx.x + threadIdx.y;
  int m = gridDim.x;
  int k = blockDim.x;
  if (rid < m) {
    int offset = rid * k + threadIdx.x;
    int max_id = max_mask[offset];  // argmax neighbor for this (row, channel)
    float grad_tmp = grad[offset];
    atomicAdd(&out[max_id * k + threadIdx.x], grad_tmp);
  }
}

std::vector<torch::Tensor> scatter_max_fp_cuda(torch::Tensor rowptr,
                                               torch::Tensor colind,
                                               torch::Tensor node_feature) {
  const long m = rowptr.size(0) - 1;
  const long k = node_feature.size(1);
  auto devid = node_feature.device().index();
  auto optionsI =
      torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA, devid);
  auto optionsF =
      torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA, devid);
  auto max_mask = torch::empty({m, k}, optionsI);
  auto out = torch::empty({m, k}, optionsF);
  // One block per row, one thread per feature channel.
  scatter_max_forward<<<m, k>>>(rowptr.data_ptr<int>(), colind.data_ptr<int>(),
                                node_feature.data_ptr<float>(),
                                out.data_ptr<float>(),
                                max_mask.data_ptr<int>());
  return {out, max_mask};
}

torch::Tensor scatter_max_bp_cuda(torch::Tensor node_feature,
                                  torch::Tensor max_mask) {
  const long m = node_feature.size(0);
  const long k = node_feature.size(1);
  auto devid = node_feature.device().index();
  auto options =
      torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA, devid);
  // Zero-initialize: the kernel accumulates into this buffer with atomicAdd.
  auto out = torch::zeros({m, k}, options);
  scatter_max_backward<<<m, k>>>(node_feature.data_ptr<float>(),
                                 out.data_ptr<float>(),
                                 max_mask.data_ptr<int>());
  return out;
}
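As a sanity-check aid (not part of the commit), a plain PyTorch loop can reproduce the forward semantics of the kernel above; scatter_max_reference and its argument names are illustrative. Comparing its output against scatter_max on the same CSR inputs is a quick way to validate the extension.

# Minimal pure-PyTorch reference for the forward pass; a simple loop, not optimized.
import torch

def scatter_max_reference(rowptr, colind, feat):
    m = rowptr.numel() - 1
    out = torch.zeros(m, feat.size(1), dtype=feat.dtype, device=feat.device)
    for row in range(m):
        lo, hi = int(rowptr[row]), int(rowptr[row + 1])
        if hi > lo:
            # Per-channel max over the features of this row's neighbors.
            out[row] = feat[colind[lo:hi].long()].max(dim=0).values
    return out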