Skip to content

Commit

Permalink
[Feature] Support ogbn-proteins with edge_feature (#261)
Browse files Browse the repository at this point in the history
* Support ogbn-proteins with edge_feature

* fix bugs

* modify revgnn

* fix bugs

* fix typo
  • Loading branch information
THINK2TRY authored Jul 24, 2021
1 parent 81d8989 commit c4d36c2
Show file tree
Hide file tree
Showing 13 changed files with 199 additions and 72 deletions.
9 changes: 5 additions & 4 deletions cogdl/data/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ def add_remaining_self_loops(self):
self.row_ptr, reindex = coo2csr_index(self.row, self.col, num_nodes=self.num_nodes)
self.row = self.row[reindex]
self.col = self.col[reindex]
self.attr = None
# if self.attr is not None:

def remove_self_loops(self):
mask = self.row == self.col
Expand Down Expand Up @@ -276,10 +276,11 @@ def edge_index(self):
@edge_index.setter
def edge_index(self, edge_index):
row, col = edge_index
if self.row is not None and self.row.shape == row.shape:
return
# if self.row is not None and self.row.shape == row.shape:
# return
self.row, self.col = row, col
self.convert_csr()
# self.convert_csr()
self.row_ptr = None

@property
def row_indptr(self):
Expand Down
4 changes: 4 additions & 0 deletions cogdl/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,10 @@ def num_classes(self):
return 0
return y.max().item() + 1 if y.dim() == 1 else y.size(1)

@property
def edge_attr_size(self):
return None

def __repr__(self): # pragma: no cover
return "{}({})".format(self.__class__.__name__, len(self))

Expand Down
4 changes: 3 additions & 1 deletion cogdl/data/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -496,6 +496,7 @@ def __getitem__(self, idx):
batch = self.batch_idx[idx * self.batch_size : (idx + 1) * self.batch_size]
nodes = np.concatenate([self.clusters[i] for i in batch])
subgraph = self.data.subgraph(nodes)
subgraph.batch = torch.from_numpy(nodes)
return subgraph

def preprocess(self, n_cluster):
Expand Down Expand Up @@ -557,12 +558,13 @@ class RandomPartitionDataset(torch.utils.data.Dataset):
def __init__(self, dataset, n_cluster):
self.data = dataset.data
self.n_cluster = n_cluster
self.num_nodes = dataset.data.x.shape[0]
self.num_nodes = dataset.data.num_nodes
self.parts = torch.randint(0, self.n_cluster, size=(self.num_nodes,))

def __getitem__(self, idx):
node_cluster = torch.where(self.parts == idx)[0]
subgraph = self.data.subgraph(node_cluster)
subgraph.batch = node_cluster
return subgraph

def __len__(self):
Expand Down
128 changes: 96 additions & 32 deletions cogdl/datasets/ogb.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,61 +8,69 @@

from . import register_dataset
from cogdl.data import Dataset, Graph, DataLoader
from cogdl.utils import cross_entropy_loss, accuracy, remove_self_loops, coalesce
from cogdl.utils import cross_entropy_loss, accuracy, remove_self_loops, coalesce, bce_with_logits_loss
from torch_geometric.utils import to_undirected


class OGBNDataset(Dataset):
def __init__(self, root, name):
def __init__(self, root, name, transform=None):
name = name.replace("-", "_")
self.name = name
root = os.path.join(root, name)
super(OGBNDataset, self).__init__(root)
dataset = NodePropPredDataset(name, root)
self.transform = None
self.data = torch.load(self.processed_paths[0])

def get(self, idx):
assert idx == 0
return self.data

def get_loss_fn(self):
return cross_entropy_loss

def get_evaluator(self):
return accuracy

def _download(self):
pass

@property
def processed_file_names(self):
return "data_cogdl.pt"

def process(self):
name = self.name.replace("_", "-")
dataset = NodePropPredDataset(name, self.root)
graph, y = dataset[0]
x = torch.tensor(graph["node_feat"])
x = torch.tensor(graph["node_feat"]) if graph["node_feat"] is not None else None
y = torch.tensor(y.squeeze())
row, col = graph["edge_index"][0], graph["edge_index"][1]
row = torch.from_numpy(row)
col = torch.from_numpy(col)
edge_index = torch.stack([row, col], dim=0)
edge_attr = None
edge_attr = torch.as_tensor(graph["edge_feat"]) if graph["edge_feat"] is not None else graph["edge_feat"]
edge_index, edge_attr = remove_self_loops(edge_index, edge_attr)
row = torch.cat([edge_index[0], edge_index[1]])
col = torch.cat([edge_index[1], edge_index[0]])

row, col, _ = coalesce(row, col)
edge_index = torch.stack([row, col], dim=0)
if edge_attr is not None:
edge_attr = torch.cat([edge_attr, edge_attr], dim=0)

self.data = Graph(x=x, edge_index=edge_index, edge_weight=edge_attr, y=y)
self.data.num_nodes = graph["num_nodes"]
assert self.data.num_nodes == self.data.x.shape[0]
data = Graph(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y)
data.num_nodes = graph["num_nodes"]

# split
split_index = dataset.get_idx_split()
self.data.train_mask = torch.zeros(self.data.num_nodes, dtype=torch.bool)
self.data.test_mask = torch.zeros(self.data.num_nodes, dtype=torch.bool)
self.data.val_mask = torch.zeros(self.data.num_nodes, dtype=torch.bool)
self.data.train_mask[split_index["train"]] = True
self.data.test_mask[split_index["test"]] = True
self.data.val_mask[split_index["valid"]] = True
data.train_mask = torch.full((data.num_nodes,), False, dtype=torch.bool)
data.val_mask = torch.full((data.num_nodes,), False, dtype=torch.bool)
data.test_mask = torch.full((data.num_nodes,), False, dtype=torch.bool)

self.transform = None
data.train_mask[split_index["train"]] = True
data.test_mask[split_index["test"]] = True
data.val_mask[split_index["valid"]] = True

def get(self, idx):
assert idx == 0
return self.data

def get_loss_fn(self):
return cross_entropy_loss

def get_evaluator(self):
return accuracy

def _download(self):
pass

def _process(self):
pass
torch.save(data, self.processed_paths[0])
return data


@register_dataset("ogbn-arxiv")
Expand Down Expand Up @@ -96,6 +104,62 @@ def __init__(self, data_path="data"):
dataset = "ogbn-proteins"
super(OGBProteinsDataset, self).__init__(data_path, dataset)

@property
def edge_attr_size(self):
return [
self.data.edge_attr.shape[1],
]

def get_loss_fn(self):
return bce_with_logits_loss

def get_evaluator(self):
evaluator = NodeEvaluator(name="ogbn-proteins")

def wrap(y_pred, y_true):
input_dict = {"y_true": y_true, "y_pred": y_pred}
return evaluator.eval(input_dict)["rocauc"]

return wrap

def process(self):
name = self.name.replace("_", "-")
dataset = NodePropPredDataset(name, self.root)
graph, y = dataset[0]
y = torch.tensor(y.squeeze())
row, col = graph["edge_index"][0], graph["edge_index"][1]
row = torch.from_numpy(row)
col = torch.from_numpy(col)
edge_attr = torch.as_tensor(graph["edge_feat"]) if "edge_feat" in graph else None

data = Graph(x=None, edge_index=(row, col), edge_attr=edge_attr, y=y)
data.num_nodes = graph["num_nodes"]

# split
split_index = dataset.get_idx_split()
data.train_mask = torch.full((data.num_nodes,), False, dtype=torch.bool)
data.val_mask = torch.full((data.num_nodes,), False, dtype=torch.bool)
data.test_mask = torch.full((data.num_nodes,), False, dtype=torch.bool)

data.train_mask[split_index["train"]] = True
data.test_mask[split_index["test"]] = True
data.val_mask[split_index["valid"]] = True

edge_attr = data.edge_attr
deg = data.degrees()
dst, _ = data.edge_index
dst = dst.view(-1, 1).expand(dst.shape[0], edge_attr.shape[1])
x = torch.zeros((data.num_nodes, edge_attr.shape[1]), dtype=torch.float32)
x = x.scatter_add_(dim=0, index=dst, src=edge_attr)
deg = torch.clamp(deg, min=1)
x = x / deg.view(-1, 1)
data.x = x

data.node_species = torch.as_tensor(graph["node_species"])

torch.save(data, self.processed_paths[0])
return data


@register_dataset("ogbn-papers100M")
class OGBPapers100MDataset(OGBNDataset):
Expand Down
6 changes: 3 additions & 3 deletions cogdl/datasets/planetoid_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,9 @@ def edge_index_from_dict(graph_dict, num_nodes=None):
row, col = [], []
for key, value in graph_dict.items():
row.append(np.repeat(key, len(value)))
col.append(value)
_row = np.concatenate(np.array(row))
_col = np.concatenate(np.array(col))
col.append(np.array(value))
_row = np.concatenate(row)
_col = np.concatenate(col)
edge_index = np.stack([_row, _col], axis=0)

row_dom = edge_index[:, _row > _col]
Expand Down
54 changes: 30 additions & 24 deletions cogdl/layers/deepergcn_layer.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
Expand All @@ -10,22 +12,20 @@
class GENConv(nn.Module):
def __init__(
self,
in_feats,
out_feats,
aggr="softmax_sg",
beta=1.0,
p=1.0,
learn_beta=False,
learn_p=False,
use_msg_norm=False,
learn_msg_scale=True,
norm=None,
residual=False,
activation=None,
num_mlp_layers=2,
edge_attr_size=[
-1,
],
in_feats: int,
out_feats: int,
aggr: str = "softmax_sg",
beta: float = 1.0,
p: float = 1.0,
learn_beta: bool = False,
learn_p: bool = False,
use_msg_norm: bool = False,
learn_msg_scale: bool = True,
norm: Optional[str] = None,
residual: bool = False,
activation: Optional[str] = None,
num_mlp_layers: int = 2,
edge_attr_size: Optional[list] = None,
):
super(GENConv, self).__init__()
self.use_msg_norm = use_msg_norm
Expand Down Expand Up @@ -63,8 +63,8 @@ def __init__(
self.norm = None if norm is None else get_norm_layer(norm, in_feats)
self.residual = residual

if edge_attr_size[0] > 0:
if len(edge_attr_size) > 0:
if edge_attr_size is not None and edge_attr_size[0] > 0:
if len(edge_attr_size) > 1:
self.edge_encoder = BondEncoder(edge_attr_size, in_feats)
else:
self.edge_encoder = EdgeEncoder(edge_attr_size[0], in_feats)
Expand Down Expand Up @@ -145,35 +145,41 @@ def __init__(
dropout=0.0,
out_norm=None,
out_channels=-1,
residual=True,
checkpoint_grad=False,
):
super(ResGNNLayer, self).__init__()
self.conv = conv
self.activation = get_activation(activation)
self.dropout = dropout
self.norm = get_norm_layer(norm, in_channels)
self.checkpoint_grad = checkpoint_grad
self.residual = residual
if out_norm:
self.out_norm = get_norm_layer(norm, out_channels)
else:
self.out_norm = None
self.checkpoint_grad = False

def forward(self, graph, x, dropout=None):
def forward(self, graph, x, dropout=None, *args, **kwargs):
h = self.norm(x)
h = self.activation(h)
if isinstance(dropout, float) or dropout is None:
h = F.dropout(h, p=self.dropout, training=self.training)
else:
if self.training:
h = h * dropout

if self.checkpoint_grad:
h = checkpoint(self.conv, graph, h)
h = checkpoint(self.conv, graph, h, *args, **kwargs)
else:
h = self.conv(graph, h)
h = self.conv(graph, h, *args, **kwargs)
if self.residual:
h = h + x

if self.out_norm:
return self.out_norm(x + h)
return self.out_norm(h)
else:
return x + h
return h


class EdgeEncoder(nn.Module):
Expand Down
1 change: 1 addition & 0 deletions cogdl/match.yml
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ node_classification:
- pubmed
- ogbn-arxiv
- ogbn-products
- ogbn-proteins
- ogbn-papers100M
- flickr
- amazon-s
Expand Down
1 change: 0 additions & 1 deletion cogdl/models/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ def predict(self, data):
def node_classification_loss(self, data, mask=None):
if mask is None:
mask = data.train_mask
assert mask.shape[0] == data.y.shape[0]
pred = self.forward(data)
return self.loss_fn(pred[mask], data.y[mask])

Expand Down
3 changes: 3 additions & 0 deletions cogdl/models/nn/deepergcn.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ def build_model_from_args(cls, args):
learn_p=args.learn_p,
learn_msg_scale=args.learn_msg_scale,
use_msg_norm=args.use_msg_norm,
edge_attr_size=args.edge_attr_size,
)

def __init__(
Expand All @@ -60,6 +61,7 @@ def __init__(
learn_p=False,
learn_msg_scale=True,
use_msg_norm=False,
edge_attr_size=None,
):
super(DeeperGCN, self).__init__()
self.dropout = dropout
Expand All @@ -80,6 +82,7 @@ def __init__(
learn_p=learn_p,
use_msg_norm=use_msg_norm,
learn_msg_scale=learn_msg_scale,
edge_attr_size=edge_attr_size,
),
in_channels=hidden_size,
activation=activation,
Expand Down
1 change: 0 additions & 1 deletion cogdl/models/nn/gcn.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,6 @@ def __init__(
]
)
self.num_layers = num_layers
self.dropout = dropout

def embed(self, graph):
graph.sym_norm()
Expand Down
Loading

0 comments on commit c4d36c2

Please sign in to comment.