diff --git a/cogdl/operators/spmm.py b/cogdl/operators/spmm.py
index 692c4ff4..8c619b9e 100644
--- a/cogdl/operators/spmm.py
+++ b/cogdl/operators/spmm.py
@@ -31,6 +31,19 @@ def csrspmm(rowptr, colind, x, csr_data, sym=False, actnn=False):
     csrspmm = None
 
 
+try:
+    spmm_cpu = load(
+        name="spmm_cpu",
+        extra_cflags=["-fopenmp"],
+        sources=[os.path.join(path, "spmm/spmm_cpu.cpp")],
+        verbose=False,
+    )
+    spmm_cpu = spmm_cpu.csr_spmm_cpu
+except Exception as e:
+    print(e)
+    spmm_cpu = None
+
+
 class SPMMFunction(torch.autograd.Function):
     @staticmethod
     def forward(ctx, rowptr, colind, feat, edge_weight_csr=None, sym=False):
diff --git a/cogdl/operators/spmm/spmm_cpu.cpp b/cogdl/operators/spmm/spmm_cpu.cpp
new file mode 100644
index 00000000..bd8683f8
--- /dev/null
+++ b/cogdl/operators/spmm/spmm_cpu.cpp
@@ -0,0 +1,75 @@
+#include <torch/extension.h>
+#include <iostream>
+#include <vector>
+#include <pybind11/pybind11.h>
+
+// Multiply a CSR sparse matrix (rowptr, colind, values) by the row-major dense
+// matrix `dense` and return the dense product. Callers must validate inputs
+// first; csr_spmm_cpu below does so.
+torch::Tensor spmm_cpu(
+    torch::Tensor rowptr,
+    torch::Tensor colind,
+    torch::Tensor values,
+    torch::Tensor dense)
+{
+    const int64_t m = rowptr.size(0) - 1;
+    const int64_t k = dense.size(1);
+    auto options = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCPU);
+    // The loop below accumulates with +=, so the output must start zeroed.
+    auto out = torch::zeros({m, k}, options);
+
+    const int *rowptr_ptr = rowptr.data_ptr<int>();
+    const int *colind_ptr = colind.data_ptr<int>();
+    const float *values_ptr = values.data_ptr<float>();
+    const float *dense_ptr = dense.data_ptr<float>();
+    float *out_ptr = out.data_ptr<float>();
+
+    // Each iteration writes only its own output row, so rows run in parallel.
+    #pragma omp parallel for schedule(dynamic)
+    for (int64_t i = 0; i < m; ++i) {
+        const int row_start = rowptr_ptr[i], row_end = rowptr_ptr[i + 1];
+        // Offsets are 64-bit: row * k can exceed the range of int.
+        float *out_row = out_ptr + i * k;
+        for (int key = row_start; key < row_end; ++key) {
+            const float *dense_row = dense_ptr + static_cast<int64_t>(colind_ptr[key]) * k;
+            const float val = values_ptr[key];
+            for (int64_t t = 0; t < k; ++t) {
+                out_row[t] += val * dense_row[t];
+            }
+        }
+    }
+    return out;
+}
+
+// Validate the CSR operands, then return A * B computed on the CPU.
+// TORCH_CHECK raises an error the Python caller can catch and, unlike assert,
+// is not compiled out when NDEBUG is defined.
+torch::Tensor csr_spmm_cpu(
+    torch::Tensor A_rowptr,
+    torch::Tensor A_colind,
+    torch::Tensor A_csrVal,
+    torch::Tensor B)
+{
+    TORCH_CHECK(A_rowptr.device().type() == torch::kCPU, "A_rowptr must be a CPU tensor");
+    TORCH_CHECK(A_colind.device().type() == torch::kCPU, "A_colind must be a CPU tensor");
+    TORCH_CHECK(A_csrVal.device().type() == torch::kCPU, "A_csrVal must be a CPU tensor");
+    TORCH_CHECK(B.device().type() == torch::kCPU, "B must be a CPU tensor");
+    TORCH_CHECK(A_rowptr.is_contiguous(), "A_rowptr must be contiguous");
+    TORCH_CHECK(A_colind.is_contiguous(), "A_colind must be contiguous");
+    TORCH_CHECK(A_csrVal.is_contiguous(), "A_csrVal must be contiguous");
+    TORCH_CHECK(B.is_contiguous(), "B must be contiguous");
+    TORCH_CHECK(A_rowptr.dtype() == torch::kInt32, "A_rowptr must be int32");
+    TORCH_CHECK(A_colind.dtype() == torch::kInt32, "A_colind must be int32");
+    TORCH_CHECK(A_csrVal.dtype() == torch::kFloat32, "A_csrVal must be float32");
+    TORCH_CHECK(B.dtype() == torch::kFloat32, "B must be float32");
+    TORCH_CHECK(A_rowptr.dim() == 1 && A_rowptr.size(0) >= 1, "A_rowptr must be 1-D and non-empty");
+    TORCH_CHECK(A_colind.numel() == A_csrVal.numel(), "A_colind and A_csrVal must have the same number of elements");
+    TORCH_CHECK(B.dim() == 2, "B must be a 2-D matrix");
+    return spmm_cpu(A_rowptr, A_colind, A_csrVal, B);
+}
+
+PYBIND11_MODULE(spmm_cpu, m)
+{
+    m.doc() = "spmm_cpu in CSR format.";
+    m.def("csr_spmm_cpu", &csr_spmm_cpu, "CSR SPMM (CPU)");
+}
\ No newline at end of file
diff --git a/cogdl/utils/spmm_utils.py b/cogdl/utils/spmm_utils.py
index 34b58bd4..cfdc2acb 100644
--- a/cogdl/utils/spmm_utils.py
+++ b/cogdl/utils/spmm_utils.py
@@ -6,9 +6,11 @@
     "csrmhspmm": None,
     "csr_edge_softmax": None,
     "fused_gat_func": None,
+    "fast_spmm_cpu": None,
     "spmm_flag": False,
     "mh_spmm_flag": False,
     "fused_gat_flag": False,
+    "spmm_cpu_flag": False,
 }
 
 
@@ -28,6 +30,17 @@ def initialize_spmm():
         #     print("Failed to load fast version of SpMM, use torch.scatter_add instead.")
 
 
+def initialize_spmm_cpu():
+    """Store the compiled CPU SpMM kernel in CONFIGS once (None if it failed to load)."""
+    if CONFIGS["spmm_cpu_flag"]:
+        return
+    CONFIGS["spmm_cpu_flag"] = True
+
+    from cogdl.operators.spmm import spmm_cpu
+
+    CONFIGS["fast_spmm_cpu"] = spmm_cpu
+
+
 def spmm_scatter(row, col, values, b):
     r"""
     Args:
@@ -40,6 +53,53 @@ def spmm_scatter(row, col, values, b):
     return output
 
 
+def spmm_cpu(graph, x, fast_spmm_cpu=None):
+    """Sparse-dense product of the graph's CSR adjacency with ``x`` on CPU.
+
+    Uses the compiled kernel when it is loaded, ``x`` is a CPU tensor and
+    autograd does not need the result; otherwise falls back to
+    ``spmm_scatter``.
+    """
+    if fast_spmm_cpu is None:
+        initialize_spmm_cpu()
+        fast_spmm_cpu = CONFIGS["fast_spmm_cpu"]
+    # The compiled kernel defines no backward pass, so its output would be
+    # cut off from the autograd graph; only use it when no gradient is needed.
+    needs_grad = torch.is_grad_enabled() and x.requires_grad
+    if fast_spmm_cpu is not None and x.device.type == "cpu" and not needs_grad:
+        if graph.out_norm is not None:
+            x = graph.out_norm * x
+
+        row_ptr, col_indices = graph.row_indptr, graph.col_indices
+        csr_data = graph.raw_edge_weight
+        # The kernel rejects non-contiguous inputs.
+        x = fast_spmm_cpu(
+            row_ptr.int().contiguous(),
+            col_indices.int().contiguous(),
+            csr_data.contiguous(),
+            x.contiguous(),
+        )
+
+        if graph.in_norm is not None:
+            x = graph.in_norm * x
+    else:
+        row, col = graph.edge_index
+        x = spmm_scatter(row, col, graph.edge_weight, x)
+    return x
+
+
+class SpMM_CPU(torch.nn.Module):
+    """Module wrapper around ``spmm_cpu`` that looks the kernel up once."""
+
+    def __init__(self):
+        super().__init__()
+        initialize_spmm_cpu()
+        self.fast_spmm_cpu = CONFIGS["fast_spmm_cpu"]
+
+    def forward(self, graph, x):
+        return spmm_cpu(graph, x, self.fast_spmm_cpu)
+
+
 def spmm(graph, x, actnn=False, fast_spmm=None):
     if fast_spmm is None:
         initialize_spmm()
diff --git a/examples/notebooks/quickstart.ipynb b/examples/notebooks/quickstart.ipynb index 2774c40e..6a0754ba 100644 --- a/examples/notebooks/quickstart.ipynb +++ b/examples/notebooks/quickstart.ipynb @@ -179,7 +179,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 3, "metadata": { "colab": { "base_uri": "https://localhost:8080/" @@ -192,148 +192,174 @@ "name": "stderr", "output_type": "stream", "text": [ - "\u001b[32m[I 2021-07-12 13:47:32,783]\u001b[0m A new study created in memory with name: no-name-5d047d75-cfbe-465b-bfc6-d5e01052dec9\u001b[0m\n", - "Epoch: 022, Train: 0.8571, Val: 0.5780, ValLoss: 1.8665: 3%|▎ | 14/500 [00:00<00:03, 135.58it/s]" + "\u001b[32m[I 2021-11-20 22:49:00,335]\u001b[0m A new study created in memory with name: no-name-9dd402de-0057-4957-a224-90f5f9f5fda1\u001b[0m\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "Namespace(activation='relu', checkpoint=None, cpu=False, dataset='cora', device_id=[0], dropout=0.7571220636246204, fast_spmm=False, func_search=, hidden_size=32, inference=False, lr=0.005, max_epoch=500, missing_rate=0, model='gcn', norm=None, num_classes=None, num_features=None, num_layers=2, patience=100, residual=False, save_dir='.', save_model=None, seed=1, task='node_classification', trainer=None, use_best_config=False, weight_decay=0.0005)\n" + "Namespace(activation='relu', actnn=False, checkpoint_path='./checkpoints/model.pt', cpu=False, cpu_inference=False, dataset=['cora'], devices=[0], distributed=False, dropout=0.5, dw='node_classification_dw', epochs=500, eval_step=1, hidden_size=64, 
load_emb_path=None, local_rank=0, log_path='.', logger=None, lr=0.01, master_addr='localhost', master_port=13425, max_epoch=None, model=['gcn'], mw='node_classification_mw', n_trials=3, n_warmup_steps=0, no_test=False, norm=None, nstage=1, num_classes=None, num_features=None, num_layers=2, patience=100, progress_bar='epoch', project='cogdl-exp', residual=False, resume_training=False, rp_ratio=1, save_emb_path=None, search_space=, seed=[1, 2], split=[0], unsup=False, use_best_config=False, weight_decay=0)\n", + "{'lr': 0.001, 'hidden_size': 64, 'dropout': 0.7053281088699626}\n", + " \n", + "|----------------------------------------------------------------------------------------|\n", + " *** Running (`cora`, `gcn`, `node_classification_dw`, `node_classification_mw`)\n", + "|----------------------------------------------------------------------------------------|\n", + "Model Parameters: 92231\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "Epoch: 499, Train: 1.0000, Val: 0.7940, ValLoss: 0.7615: 100%|██████████| 500/500 [00:03<00:00, 133.73it/s]\n", - "Epoch: 020, Train: 0.9429, Val: 0.7000, ValLoss: 1.8790: 3%|▎ | 13/500 [00:00<00:03, 123.07it/s]" + "Epoch: 499, train_loss: 0.1821, val_acc: 0.7860: 100%|██████████| 500/500 [00:03<00:00, 154.53it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "Valid accurracy = 0.7820\n", - "Test accuracy = 0.8110\n", - "Namespace(activation='relu', checkpoint=None, cpu=False, dataset='cora', device_id=[0], dropout=0.7571220636246204, fast_spmm=False, func_search=, hidden_size=32, inference=False, lr=0.005, max_epoch=500, missing_rate=0, model='gcn', norm=None, num_classes=None, num_features=None, num_layers=2, patience=100, residual=False, save_dir='.', save_model=None, seed=2, task='node_classification', trainer=None, use_best_config=False, weight_decay=0.0005)\n" + "Saving 452-th model to ./checkpoints/model.pt ...\n", + "Loading model from ./checkpoints/model.pt ...\n", + "{'test_acc': 
0.813, 'val_acc': 0.79}\n", + " \n", + "|----------------------------------------------------------------------------------------|\n", + " *** Running (`cora`, `gcn`, `node_classification_dw`, `node_classification_mw`)\n", + "|----------------------------------------------------------------------------------------|\n", + "Model Parameters: 92231\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "Epoch: 499, Train: 1.0000, Val: 0.7860, ValLoss: 0.7639: 100%|██████████| 500/500 [00:03<00:00, 132.70it/s]\n", - "\u001b[32m[I 2021-07-12 13:47:40,393]\u001b[0m Trial 0 finished with value: 0.783 and parameters: {'lr': 0.005, 'hidden_size': 32, 'dropout': 0.7571220636246204}. Best is trial 0 with value: 0.783.\u001b[0m\n", - "Epoch: 016, Train: 0.9714, Val: 0.7560, ValLoss: 1.8282: 2%|▏ | 12/500 [00:00<00:04, 118.41it/s]" + "Epoch: 402, train_loss: 0.2834, val_acc: 0.7840: 81%|████████ | 403/500 [00:02<00:00, 152.21it/s]\n", + "\u001b[32m[I 2021-11-20 22:49:06,293]\u001b[0m Trial 0 finished with value: 0.792 and parameters: {'lr': 0.001, 'hidden_size': 64, 'dropout': 0.7053281088699626}. 
Best is trial 0 with value: 0.792.\u001b[0m\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "Valid accurracy = 0.7840\n", - "Test accuracy = 0.8090\n", - "| Variant | Acc | ValAcc |\n", + "Saving 303-th model to ./checkpoints/model.pt ...\n", + "Loading model from ./checkpoints/model.pt ...\n", + "{'test_acc': 0.812, 'val_acc': 0.794}\n", + "| Variant | test_acc | val_acc |\n", "|-----------------|---------------|---------------|\n", - "| ('cora', 'gcn') | 0.8100±0.0010 | 0.7830±0.0010 |\n", - "Namespace(activation='relu', checkpoint=None, cpu=False, dataset='cora', device_id=[0], dropout=0.6352487471594475, fast_spmm=False, func_search=, hidden_size=64, inference=False, lr=0.005, max_epoch=500, missing_rate=0, model='gcn', norm=None, num_classes=None, num_features=None, num_layers=2, patience=100, residual=False, save_dir='.', save_model=None, seed=1, task='node_classification', trainer=None, use_best_config=False, weight_decay=0.0005)\n" + "| ('cora', 'gcn') | 0.8125±0.0005 | 0.7920±0.0020 |\n", + "{'lr': 0.005, 'hidden_size': 128, 'dropout': 0.7457077778269245}\n", + " \n", + "|----------------------------------------------------------------------------------------|\n", + " *** Running (`cora`, `gcn`, `node_classification_dw`, `node_classification_mw`)\n", + "|----------------------------------------------------------------------------------------|\n", + "Model Parameters: 184455\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "Epoch: 499, Train: 1.0000, Val: 0.7960, ValLoss: 0.7289: 100%|██████████| 500/500 [00:03<00:00, 129.73it/s]\n", - "Epoch: 020, Train: 0.9571, Val: 0.7680, ValLoss: 1.7769: 3%|▎ | 13/500 [00:00<00:04, 120.40it/s]" + "Epoch: 436, train_loss: 0.0047, val_acc: 0.7900: 87%|████████▋ | 437/500 [00:02<00:00, 164.92it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "Valid accurracy = 0.7920\n", - "Test accuracy = 0.8120\n", - "Namespace(activation='relu', checkpoint=None, cpu=False, 
dataset='cora', device_id=[0], dropout=0.6352487471594475, fast_spmm=False, func_search=, hidden_size=64, inference=False, lr=0.005, max_epoch=500, missing_rate=0, model='gcn', norm=None, num_classes=None, num_features=None, num_layers=2, patience=100, residual=False, save_dir='.', save_model=None, seed=2, task='node_classification', trainer=None, use_best_config=False, weight_decay=0.0005)\n" + "Saving 337-th model to ./checkpoints/model.pt ...\n", + "Loading model from ./checkpoints/model.pt ...\n", + "{'test_acc': 0.8, 'val_acc': 0.796}\n", + " \n", + "|----------------------------------------------------------------------------------------|\n", + " *** Running (`cora`, `gcn`, `node_classification_dw`, `node_classification_mw`)\n", + "|----------------------------------------------------------------------------------------|\n", + "Model Parameters: 184455\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "Epoch: 499, Train: 1.0000, Val: 0.7880, ValLoss: 0.7334: 100%|██████████| 500/500 [00:03<00:00, 129.88it/s]\n", - "\u001b[32m[I 2021-07-12 13:47:48,174]\u001b[0m Trial 1 finished with value: 0.795 and parameters: {'lr': 0.005, 'hidden_size': 64, 'dropout': 0.6352487471594475}. Best is trial 1 with value: 0.795.\u001b[0m\n", - "Epoch: 017, Train: 0.8571, Val: 0.4840, ValLoss: 1.9058: 2%|▏ | 12/500 [00:00<00:04, 119.36it/s]" + "Epoch: 221, train_loss: 0.0144, val_acc: 0.7880: 44%|████▍ | 222/500 [00:01<00:01, 164.91it/s]\n", + "\u001b[32m[I 2021-11-20 22:49:10,357]\u001b[0m Trial 1 finished with value: 0.795 and parameters: {'lr': 0.005, 'hidden_size': 128, 'dropout': 0.7457077778269245}. 
Best is trial 1 with value: 0.795.\u001b[0m\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "Valid accurracy = 0.7980\n", - "Test accuracy = 0.8110\n", - "| Variant | Acc | ValAcc |\n", + "Saving 122-th model to ./checkpoints/model.pt ...\n", + "Loading model from ./checkpoints/model.pt ...\n", + "{'test_acc': 0.807, 'val_acc': 0.794}\n", + "| Variant | test_acc | val_acc |\n", "|-----------------|---------------|---------------|\n", - "| ('cora', 'gcn') | 0.8115±0.0005 | 0.7950±0.0030 |\n", - "Namespace(activation='relu', checkpoint=None, cpu=False, dataset='cora', device_id=[0], dropout=0.5363238401043926, fast_spmm=False, func_search=, hidden_size=128, inference=False, lr=0.001, max_epoch=500, missing_rate=0, model='gcn', norm=None, num_classes=None, num_features=None, num_layers=2, patience=100, residual=False, save_dir='.', save_model=None, seed=1, task='node_classification', trainer=None, use_best_config=False, weight_decay=0.0005)\n" + "| ('cora', 'gcn') | 0.8035±0.0035 | 0.7950±0.0010 |\n", + "{'lr': 0.01, 'hidden_size': 32, 'dropout': 0.7366976302737803}\n", + " \n", + "|----------------------------------------------------------------------------------------|\n", + " *** Running (`cora`, `gcn`, `node_classification_dw`, `node_classification_mw`)\n", + "|----------------------------------------------------------------------------------------|\n", + "Model Parameters: 46119\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "Epoch: 499, Train: 1.0000, Val: 0.7860, ValLoss: 0.8297: 100%|██████████| 500/500 [00:04<00:00, 119.69it/s]\n", - "Epoch: 021, Train: 0.9357, Val: 0.6740, ValLoss: 1.8943: 3%|▎ | 13/500 [00:00<00:03, 126.40it/s]" + "Epoch: 222, train_loss: 0.0549, val_acc: 0.7860: 45%|████▍ | 223/500 [00:01<00:01, 152.79it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "Valid accurracy = 0.7860\n", - "Test accuracy = 0.8140\n", - "Namespace(activation='relu', checkpoint=None, cpu=False, dataset='cora', 
device_id=[0], dropout=0.5363238401043926, fast_spmm=False, func_search=, hidden_size=128, inference=False, lr=0.001, max_epoch=500, missing_rate=0, model='gcn', norm=None, num_classes=None, num_features=None, num_layers=2, patience=100, residual=False, save_dir='.', save_model=None, seed=2, task='node_classification', trainer=None, use_best_config=False, weight_decay=0.0005)\n" + "Saving 123-th model to ./checkpoints/model.pt ...\n", + "Loading model from ./checkpoints/model.pt ...\n", + "{'test_acc': 0.817, 'val_acc': 0.796}\n", + " \n", + "|----------------------------------------------------------------------------------------|\n", + " *** Running (`cora`, `gcn`, `node_classification_dw`, `node_classification_mw`)\n", + "|----------------------------------------------------------------------------------------|\n", + "Model Parameters: 46119\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "Epoch: 499, Train: 1.0000, Val: 0.7940, ValLoss: 0.8240: 100%|██████████| 500/500 [00:04<00:00, 120.83it/s]\n", - "\u001b[32m[I 2021-07-12 13:47:56,568]\u001b[0m Trial 2 finished with value: 0.79 and parameters: {'lr': 0.001, 'hidden_size': 128, 'dropout': 0.5363238401043926}. Best is trial 1 with value: 0.795.\u001b[0m\n" + "Epoch: 214, train_loss: 0.0458, val_acc: 0.7760: 43%|████▎ | 215/500 [00:01<00:01, 152.77it/s]\n", + "\u001b[32m[I 2021-11-20 22:49:13,296]\u001b[0m Trial 2 finished with value: 0.797 and parameters: {'lr': 0.01, 'hidden_size': 32, 'dropout': 0.7366976302737803}. 
Best is trial 2 with value: 0.797.\u001b[0m\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "Valid accurracy = 0.7940\n", - "Test accuracy = 0.8190\n", - "| Variant | Acc | ValAcc |\n", + "Saving 115-th model to ./checkpoints/model.pt ...\n", + "Loading model from ./checkpoints/model.pt ...\n", + "{'test_acc': 0.804, 'val_acc': 0.798}\n", + "| Variant | test_acc | val_acc |\n", "|-----------------|---------------|---------------|\n", - "| ('cora', 'gcn') | 0.8165±0.0025 | 0.7900±0.0040 |\n", - "{'lr': 0.005, 'hidden_size': 64, 'dropout': 0.6352487471594475}\n", + "| ('cora', 'gcn') | 0.8105±0.0065 | 0.7970±0.0010 |\n", + "{'lr': 0.01, 'hidden_size': 32, 'dropout': 0.7366976302737803}\n", "\n", "Final results:\n", "\n", - "| Variant | Acc | ValAcc |\n", + "| Variant | test_acc | val_acc |\n", "|-----------------|---------------|---------------|\n", - "| ('cora', 'gcn') | 0.8115±0.0005 | 0.7950±0.0030 |\n" + "| ('cora', 'gcn') | 0.8105±0.0065 | 0.7970±0.0010 |\n" ] }, { "data": { "text/plain": [ "defaultdict(list,\n", - " {('cora', 'gcn'): [{'Acc': 0.812, 'ValAcc': 0.792},\n", - " {'Acc': 0.811, 'ValAcc': 0.798}]})" + " {('cora', 'gcn'): [{'test_acc': 0.817, 'val_acc': 0.796},\n", + " {'test_acc': 0.804, 'val_acc': 0.798}]})" ] }, - "execution_count": 4, - "metadata": { - "tags": [] - }, + "execution_count": 3, + "metadata": {}, "output_type": "execute_result" } ],