Skip to content

Commit

Permalink
[Model] Update GCC model (#392)
Browse files Browse the repository at this point in the history
* fixed flake8

* update parallel

* delete test_gcc

* update readme

* Update .gitignore

* fixed RWR

* update test_gcc

* update code

Co-authored-by: Yukuo Cen <cenyk1230@qq.com>
  • Loading branch information
hwangyeong and cenyk1230 authored Dec 7, 2022
1 parent ab37cd1 commit 5b62c61
Show file tree
Hide file tree
Showing 22 changed files with 700 additions and 89 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -115,4 +115,5 @@ metis*
*.dict
*.csv
*.sql
*.pt
*.pt
*.npz
12 changes: 6 additions & 6 deletions cogdl/data/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -436,11 +436,11 @@ def to_networkx(self, weighted=True):
gnx.add_edges_from(edges)
return gnx

def random_walk(self, seeds, length=1, restart_p=0.0):
def random_walk(self, seeds, length=1, restart_p=0.0, parallel=True):
if not hasattr(self, "__walker__"):
scipy_adj = self.to_scipy_csr()
self.__walker__ = RandomWalker(scipy_adj)
return self.__walker__.walk(seeds, length, restart_p=restart_p)
return self.__walker__.walk(seeds, length, restart_p=restart_p, parallel=parallel)

@staticmethod
def from_dict(dictionary):
Expand Down Expand Up @@ -915,11 +915,11 @@ def edge_subgraph(self, edge_idx, require_idx=True):
else:
return g

def random_walk(self, seeds, max_nodes_per_seed, restart_p=0.0):
return self._adj.random_walk(seeds, max_nodes_per_seed, restart_p)
def random_walk(self, seeds, max_nodes_per_seed, restart_p=0.0, parallel=True):
return self._adj.random_walk(seeds, max_nodes_per_seed, restart_p, parallel)

def random_walk_with_restart(self, seeds, max_nodes_per_seed, restart_p=0.0):
return self._adj.random_walk(seeds, max_nodes_per_seed, restart_p)
def random_walk_with_restart(self, seeds, max_nodes_per_seed, restart_p=0.0, parallel=True):
return self._adj.random_walk(seeds, max_nodes_per_seed, restart_p, parallel)

def to_scipy_csr(self):
return self._adj.to_scipy_csr()
Expand Down
1 change: 1 addition & 0 deletions cogdl/data/dataset.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import collections
import os.path as osp
from itertools import repeat

import numpy as np

import torch.utils.data
Expand Down
29 changes: 28 additions & 1 deletion cogdl/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,10 +60,30 @@ def build_dataset_from_name(dataset, split=0):
return dataset_class()


def build_dataset_pretrain(args):
args.pretrain = False
dataset_names = args.dataset
if ' ' in args.dataset:
datasets_name = args.dataset.split(' ')
dataset = []
for dataset_ in datasets_name:
args.dataset = dataset_
dataset.append(build_dataset(args))
else:
dataset = [build_dataset(args)]
args.pretrain = True
args.dataset = dataset_names
dataset_class = getattr(importlib.import_module("cogdl.datasets.gcc_data"), "PretrainDataset")
return dataset_class(args.dataset, [x.data for x in dataset])


def build_dataset(args):
if not hasattr(args, "split"):
args.split = 0
dataset = build_dataset_from_name(args.dataset, args.split)
if not hasattr(args, "pretrain") or not args.pretrain:
dataset = build_dataset_from_name(args.dataset, args.split)
else:
dataset = build_dataset_pretrain(args)

if hasattr(dataset, "num_classes") and dataset.num_classes > 0:
args.num_classes = dataset.num_classes
Expand Down Expand Up @@ -96,10 +116,17 @@ def build_dataset_from_path(data_path, dataset=None):


SUPPORTED_DATASETS = {
"gcc_academic": "cogdl.datasets.gcc_data.Academic_GCCDataset",
"gcc_dblp_netrep": "cogdl.datasets.gcc_data.DBLPNetrep_GCCDataset",
"gcc_dblp_snap": "cogdl.datasets.gcc_data.DBLPSnap_GCCDataset",
"gcc_facebook": "cogdl.datasets.gcc_data.Facebook_GCCDataset",
"gcc_imdb": "cogdl.datasets.gcc_data.IMDB_GCCDataset",
"gcc_livejournal": "cogdl.datasets.gcc_data.Livejournal_GCCDataset",
"kdd_icdm": "cogdl.datasets.gcc_data.KDD_ICDM_GCCDataset",
"sigir_cikm": "cogdl.datasets.gcc_data.SIGIR_CIKM_GCCDataset",
"sigmod_icde": "cogdl.datasets.gcc_data.SIGMOD_ICDE_GCCDataset",
"usa-airport": "cogdl.datasets.gcc_data.USAAirportDataset",
"h-index": "cogdl.datasets.gcc_data.HIndexDataset",
"ogbn-arxiv": "cogdl.datasets.ogb.OGBArxivDataset",
"ogbn-products": "cogdl.datasets.ogb.OGBProductsDataset",
"ogbn-proteins": "cogdl.datasets.ogb.OGBProteinsDataset",
Expand Down
107 changes: 104 additions & 3 deletions cogdl/datasets/gcc_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from cogdl.data import Graph, Dataset
from cogdl.utils import download_url
from cogdl.utils import Accuracy, CrossEntropyLoss


class GCCDataset(Dataset):
Expand Down Expand Up @@ -96,7 +97,10 @@ def __init__(self, root, name):

@property
def raw_file_names(self):
names = ["edgelist.txt", "nodelabel.txt"]
if self.name in UNLABELED_GCCDATASETS:
names = ["edgelist.txt"]
else:
names = ["edgelist.txt", "nodelabel.txt"]
return names

@property
Expand Down Expand Up @@ -142,11 +146,11 @@ def process(self):
if label not in label2id:
label2id[label] = len(label2id)
nodes.append(node2id[x])
if "hindex" in self.name:
if "h-index" in self.name:
labels.append(label)
else:
labels.append(label2id[label])
if "hindex" in self.name:
if "h-index" in self.name:
median = np.median(labels)
labels = [int(label > median) for label in labels]
assert num_nodes == len(set(nodes))
Expand All @@ -158,6 +162,45 @@ def process(self):
torch.save(data, self.processed_paths[0])


class PretrainDataset(object):

class DataList(object):

def __init__(self, graphs):
for graph in graphs:
graph.y = None
self.graphs = graphs

def to(self, device):
return [graph.to(device) for graph in self.graphs]

def train(self):
return [graph.train() for graph in self.graphs]

def eval(self):
return [graph.eval() for graph in self.graphs]

def __init__(self, name, data):
super(PretrainDataset, self).__init__()
self.name = name
# self.data = data
self.data = self.DataList(data)

def get_evaluator(self):
return Accuracy()

def get_loss_fn(self):
return CrossEntropyLoss()

@property
def num_features(self):
return 0

def get(self, idx):
assert idx == 0
return self.data.graphs


class KDD_ICDM_GCCDataset(GCCDataset):
def __init__(self, data_path="data"):
dataset = "kdd_icdm"
Expand All @@ -184,3 +227,61 @@ def __init__(self, data_path="data"):
dataset = "usa-airport"
path = osp.join(data_path, dataset)
super(USAAirportDataset, self).__init__(path, dataset)


class HIndexDataset(Edgelist):
def __init__(self, data_path="data"):
dataset = "h-index"
path = osp.join(data_path, dataset)
super(HIndexDataset, self).__init__(path, dataset)


class Academic_GCCDataset(Edgelist):
def __init__(self, data_path="data"):
dataset = "gcc_academic"
path = osp.join(data_path, dataset)
super(Academic_GCCDataset, self).__init__(path, dataset)


class DBLPNetrep_GCCDataset(Edgelist):
def __init__(self, data_path="data"):
dataset = "gcc_dblp_netrep"
path = osp.join(data_path, dataset)
super(DBLPNetrep_GCCDataset, self).__init__(path, dataset)


class DBLPSnap_GCCDataset(Edgelist):
def __init__(self, data_path="data"):
dataset = "gcc_dblp_snap"
path = osp.join(data_path, dataset)
super(DBLPSnap_GCCDataset, self).__init__(path, dataset)


class Facebook_GCCDataset(Edgelist):
def __init__(self, data_path="data"):
dataset = "gcc_facebook"
path = osp.join(data_path, dataset)
super(Facebook_GCCDataset, self).__init__(path, dataset)


class IMDB_GCCDataset(Edgelist):
def __init__(self, data_path="data"):
dataset = "gcc_imdb"
path = osp.join(data_path, dataset)
super(IMDB_GCCDataset, self).__init__(path, dataset)


class Livejournal_GCCDataset(Edgelist):
def __init__(self, data_path="data"):
dataset = "gcc_livejournal"
path = osp.join(data_path, dataset)
super(Livejournal_GCCDataset, self).__init__(path, dataset)


UNLABELED_GCCDATASETS = ["gcc_academic",
"gcc_dblp_netrep",
"gcc_dblp_snap",
"gcc_facebook",
"gcc_imdb",
"gcc_livejournal"
]
12 changes: 11 additions & 1 deletion cogdl/experiments.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import optuna
from tabulate import tabulate

from cogdl.utils import set_random_seed, tabulate_results
from cogdl.utils import set_random_seed, tabulate_results, build_model_path
from cogdl.configs import BEST_CONFIGS
from cogdl.data import Dataset
from cogdl.models import build_model
Expand Down Expand Up @@ -111,6 +111,9 @@ def train(args): # noqa: C901
|-------------------------------------{'-' * (len(str(args.dataset)) + len(model_name) + len(dw_name) + len(mw_name))}|"""
)

if hasattr(args, "save_model_path"):
args = build_model_path(args, model_name)

if getattr(args, "use_best_config", False):
args = set_best_config(args)

Expand Down Expand Up @@ -181,6 +184,12 @@ def train(args): # noqa: C901
if hasattr(args, "hidden_size"):
optimizer_cfg["hidden_size"] = args.hidden_size

if hasattr(args, "beta1") and hasattr(args, "beta2"):
optimizer_cfg["betas"] = (args.beta1, args.beta2)

if hasattr(dataset_wrapper, "train_dataset"):
optimizer_cfg["total"] = len(dataset_wrapper.train_dataset)

# setup model_wrapper
if isinstance(args.mw, str) and "embedding" in args.mw:
model_wrapper = mw_class(model, **model_wrapper_args)
Expand Down Expand Up @@ -212,6 +221,7 @@ def train(args): # noqa: C901
fp16=args.fp16,
do_test=args.do_test,
do_valid=args.do_valid,
clip_grad_norm=args.clip_grad_norm,
)

# Go!!!
Expand Down
1 change: 1 addition & 0 deletions cogdl/layers/mlp_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ def __init__(
self.activation = get_activation(activation)
self.act_first = act_first
self.dropout = dropout
self.output_dim = out_feats
shapes = [in_feats] + [hidden_size] * (num_layers - 1) + [out_feats]
self.mlp = nn.ModuleList(
[nn.Linear(shapes[layer], shapes[layer + 1], bias=bias) for layer in range(num_layers)]
Expand Down
8 changes: 4 additions & 4 deletions cogdl/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,10 @@ def build_model(args):


SUPPORTED_MODELS = {
"transe":"cogdl.models.emb.transe.TransE",
"complex":"cogdl.models.emb.complex.ComplEx",
"distmult":"cogdl.models.emb.distmult.DistMult",
"rotate":"cogdl.models.emb.rotate.RotatE",
"transe": "cogdl.models.emb.transe.TransE",
"complex": "cogdl.models.emb.complex.ComplEx",
"distmult": "cogdl.models.emb.distmult.DistMult",
"rotate": "cogdl.models.emb.rotate.RotatE",
"hope": "cogdl.models.emb.hope.HOPE",
"spectral": "cogdl.models.emb.spectral.Spectral",
"hin2vec": "cogdl.models.emb.hin2vec.Hin2vec",
Expand Down
18 changes: 14 additions & 4 deletions cogdl/models/nn/gcc_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ def __init__(
train_eps=False,
dropout=0.5,
final_dropout=0.2,
use_selayer=False,
):
super(GINModel, self).__init__()
self.gin_layers = nn.ModuleList()
Expand All @@ -78,7 +79,7 @@ def __init__(
mlp = MLP(in_feats, hidden_dim, hidden_dim, num_mlp_layers, norm="batchnorm")
else:
mlp = MLP(hidden_dim, hidden_dim, hidden_dim, num_mlp_layers, norm="batchnorm")
self.gin_layers.append(GINLayer(mlp, eps, train_eps))
self.gin_layers.append(GINLayer(ApplyNodeFunc(mlp, use_selayer), eps, train_eps))
self.batch_norm.append(nn.BatchNorm1d(hidden_dim))

self.linear_prediction = nn.ModuleList()
Expand Down Expand Up @@ -155,9 +156,12 @@ def add_args(parser):
parser.add_argument("--max-edge-freq", type=int, default=16)
parser.add_argument("--max-degree", type=int, default=512)
parser.add_argument("--freq-embedding-size", type=int, default=16)
parser.add_argument("--num-layers", type=int, default=2)
parser.add_argument("--num-layers", type=int, default=5)
parser.add_argument("--num-heads", type=int, default=2)
parser.add_argument("--output-size", type=int, default=32)
parser.add_argument("--output-size", type=int, default=64)
parser.add_argument("--norm", type=bool, default=True)
parser.add_argument("--gnn-model", type=str, default="gin")
parser.add_argument("--degree-input", type=bool, default=True)

@classmethod
def build_model_from_args(cls, args):
Expand All @@ -170,7 +174,10 @@ def build_model_from_args(cls, args):
num_heads=args.num_heads,
degree_embedding_size=args.degree_embedding_size,
node_hidden_dim=args.hidden_size,
norm=args.norm,
gnn_model=args.gnn_model,
output_dim=args.output_size,
degree_input=args.degree_input
)

def __init__(
Expand All @@ -190,7 +197,7 @@ def __init__(
num_layer_set2set=3,
norm=False,
gnn_model="gin",
degree_input=False,
degree_input=True,
):
super(GCCModel, self).__init__()

Expand Down Expand Up @@ -229,6 +236,8 @@ def __init__(
self.max_edge_freq = max_edge_freq
self.max_degree = max_degree
self.degree_input = degree_input
self.output_dim = output_dim
self.hidden_size = node_hidden_dim

# self.node_freq_embedding = nn.Embedding(
# num_embeddings=max_node_freq + 1, embedding_dim=freq_embedding_size
Expand Down Expand Up @@ -277,6 +286,7 @@ def forward(self, g, return_all_outputs=False):
if device != torch.device("cpu"):
degrees = degrees.cuda(device)

degrees = degrees.long()
deg_emb = self.degree_embedding(degrees.clamp(0, self.max_degree))

n_feat = torch.cat((pos_undirected, deg_emb, seed_emb), dim=-1)
Expand Down
3 changes: 2 additions & 1 deletion cogdl/options.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,9 @@ def get_parser():
parser.add_argument("--patience", type=int, default=100)
parser.add_argument("--lr", default=0.01, type=float)
parser.add_argument("--weight-decay", default=0, type=float)
parser.add_argument("--n-warmup-steps", type=int, default=0)
parser.add_argument("--n-warmup-steps", type=float, default=0.)
parser.add_argument("--split", default=[0], type=int, nargs="+", metavar="N")
parser.add_argument("--clip-grad-norm", default=5., type=float)

parser.add_argument("--checkpoint-path", type=str, default="./checkpoints/model.pt", help="path to save model")
parser.add_argument("--save-emb-path", type=str, default=None, help="path to save embeddings")
Expand Down
2 changes: 1 addition & 1 deletion cogdl/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ def run(self, model_w: ModelWrapper, dataset_w: DataWrapper):

# clear the GPU memory
dataset = dataset_w.get_dataset()
if isinstance(dataset.data, Graph):
if isinstance(dataset.data, Graph) or hasattr(dataset.data, "graphs"):
dataset.data.to("cpu")

return final_test
Expand Down
Loading

0 comments on commit 5b62c61

Please sign in to comment.