[BugFix] Fix bugs in datasets and modify moe_gcn #257

Merged: 4 commits, Jul 10, 2021
cogdl/datasets/ogb.py (4 changes: 3 additions & 1 deletion)

@@ -89,8 +89,10 @@ def get_evaluator(self):
         evaluator = NodeEvaluator(name="ogbn-arxiv")

         def wrap(y_pred, y_true):
+            y_pred = y_pred.argmax(dim=-1, keepdim=True)
+            y_true = y_true.view(-1, 1)
             input_dict = {"y_true": y_true, "y_pred": y_pred}
-            return evaluator.eval(input_dict)
+            return evaluator.eval(input_dict)["acc"]

         return wrap
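The wrapped evaluator now accepts raw logits and integer labels and returns the scalar accuracy rather than OGB's result dict. A minimal sketch of the intended call pattern, assuming the standard ogb.nodeproppred Evaluator API (the diff's NodeEvaluator appears to be that class under another name); shapes and the class count are illustrative:

# Sketch only: how the new wrap() is expected to behave, assuming the
# standard ogb.nodeproppred Evaluator interface.
import torch
from ogb.nodeproppred import Evaluator

evaluator = Evaluator(name="ogbn-arxiv")

logits = torch.randn(4, 40)           # raw model outputs, one row per node
labels = torch.randint(0, 40, (4,))   # ground-truth class indices

y_pred = logits.argmax(dim=-1, keepdim=True)   # shape (N, 1), as the evaluator expects
y_true = labels.view(-1, 1)                    # shape (N, 1)
acc = evaluator.eval({"y_true": y_true, "y_pred": y_pred})["acc"]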
cogdl/layers/gcn_layer.py (23 changes: 11 additions & 12 deletions)

@@ -53,19 +53,18 @@ def reset_parameters(self):
         if self.bias is not None:
             self.bias.data.zero_()

-    def forward(self, graph, x, flag=True):
+    def forward(self, graph, x):
         support = torch.mm(x, self.weight)
         out = spmm(graph, support)
-        if flag:
-            if self.bias is not None:
-                out = out + self.bias
-            if self.norm is not None:
-                out = self.norm(out)
-            if self.act is not None:
-                out = self.act(out)
+        if self.bias is not None:
+            out = out + self.bias
+        if self.norm is not None:
+            out = self.norm(out)
+        if self.act is not None:
+            out = self.act(out)

-            if self.residual is not None:
-                out = out + self.residual(x)
-            if self.dropout is not None:
-                out = self.dropout(out)
+        if self.residual is not None:
+            out = out + self.residual(x)
+        if self.dropout is not None:
+            out = self.dropout(out)
         return out
cogdl/layers/gine_layer.py (6 changes: 4 additions & 2 deletions)

@@ -1,5 +1,6 @@
 import torch
 import torch.nn.functional as F
+from cogdl.utils import spmm

 from . import BaseLayer

@@ -27,8 +28,9 @@ def __init__(self, apply_func=None, eps=0, train_eps=True):
         self.apply_func = apply_func

     def forward(self, graph, x):
-        m = self.message(x[graph.edge_index[0]], graph.edge_attr)
-        out = self.aggregate(graph, m)
+        # m = self.message(x[graph.edge_index[0]], graph.edge_attr)
+        # out = self.aggregate(graph, m)
+        out = spmm(graph, x)
         out += (1 + self.eps) * x
         if self.apply_func is not None:
             out = self.apply_func(out)
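Replacing the message/aggregate pair with spmm(graph, x) turns the aggregation into a plain sparse matrix product over the graph, so the edge-attribute term no longer enters the sum. A toy dense sketch of what the new forward computes; A stands for the adjacency matrix that spmm applies, and the graph and feature sizes are made up:

# Dense illustration only: spmm(graph, x) corresponds to A @ x,
# i.e. summing each node's neighbor features.
import torch

A = torch.tensor([[0.0, 1.0, 0.0],   # toy undirected graph with edges 0-1 and 1-2
                  [1.0, 0.0, 1.0],
                  [0.0, 1.0, 0.0]])
x = torch.randn(3, 4)                # node features
eps = 0.0

out = A @ x                          # neighbor aggregation (spmm equivalent)
out = out + (1 + eps) * x            # GIN-style self term, as in the new forward
# apply_func (an MLP), if configured, would then be applied to `out`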
cogdl/models/nn/gcn.py (2 changes: 1 addition & 1 deletion)

@@ -69,7 +69,7 @@ def embed(self, graph):
         graph.sym_norm()
         h = graph.x
         for i in range(self.num_layers - 1):
-            h = self.layers[i](graph, h, False)
+            h = self.layers[i](graph, h)
         return h

     def forward(self, graph):
cogdl/models/nn/moe_gcn.py (112 changes: 85 additions & 27 deletions)

@@ -6,7 +6,7 @@
 from torch.nn.parameter import Parameter

 from .. import BaseModel, register_model
-from cogdl.utils import spmm
+from cogdl.utils import spmm, get_activation


 from fmoe import FMoETransformerMLP

@@ -38,11 +38,30 @@ class GraphConv(nn.Module):
     Simple GCN layer, similar to https://arxiv.org/abs/1609.02907
     """

-    def __init__(self, in_features, out_features, bias=True):
+    def __init__(self, in_features, out_features, dropout=0.0, residual=False, norm=None, bias=True):
         super(GraphConv, self).__init__()
         self.in_features = in_features
         self.out_features = out_features
         self.weight = Parameter(torch.FloatTensor(in_features, out_features))
+        if dropout > 0:
+            self.dropout = nn.Dropout(dropout)
+        else:
+            self.dropout = None
+        if residual:
+            self.residual = nn.Linear(in_features, out_features)
+        else:
+            self.residual = None
+
+        if norm is not None:
+            if norm == "batchnorm":
+                self.norm = nn.BatchNorm1d(out_features)
+            elif norm == "layernorm":
+                self.norm = nn.LayerNorm(out_features)
+            else:
+                raise NotImplementedError
+        else:
+            self.norm = None
+
         if bias:
             self.bias = Parameter(torch.FloatTensor(out_features))
         else:

@@ -59,43 +78,50 @@ def forward(self, graph, x):
         support = torch.mm(x, self.weight)
         out = spmm(graph, support)
         if self.bias is not None:
-            return out + self.bias
-        else:
-            return out
+            out = out + self.bias
+
+        if self.residual is not None:
+            res = self.residual(x)
+            if self.act is not None:
+                res = self.act(res)
+            out = out + res

-    def __repr__(self):
-        return self.__class__.__name__ + " (" + str(self.in_features) + " -> " + str(self.out_features) + ")"
+        if self.dropout is not None:
+            out = self.dropout(out)
+
+        if self.norm is not None:
+            out = self.norm(out)
+
+        return out


 class GraphConvBlock(nn.Module):
-    def __init__(self, in_feats, out_feats, activation=None, dropout=0.0):
+    def __init__(self, conv_func, conv_params, out_feats, dropout=0.0, in_feats=None, residual=False):
         super(GraphConvBlock, self).__init__()

-        self.activation = activation
-        self.graph_conv = GraphConv(in_features=in_feats, out_features=out_feats)
-        self.dropout = nn.Dropout(dropout)
-        self.res_connection = nn.Linear(in_feats, out_feats)
-        self.bn_layer_1 = nn.BatchNorm1d(out_feats)
-        self.bn_layer_2 = nn.BatchNorm1d(out_feats)
+        self.graph_conv = conv_func(**conv_params)
         self.pos_ff = CustomizedMoEPositionwiseFF(out_feats, out_feats * 2, dropout, moe_num_expert=64, moe_top_k=2)
+        self.dropout = dropout
+        if residual is True:
+            assert in_feats is not None
+            self.res_connection = nn.Linear(in_feats, out_feats)
+        else:
+            self.res_connection = None

     def reset_parameters(self):
         """Reinitialize model parameters."""
-        self.graph_conv.reset_parameters()
+        # self.graph_conv.reset_parameters()
         self.res_connection.reset_parameters()
-        self.bn_layer_1.reset_parameters()
-        self.bn_layer_2.reset_parameters()

     def forward(self, graph, feats):
         new_feats = self.graph_conv(graph, feats)
-        res_feats = self.res_connection(feats)
-        if self.activation is not None:
-            res_feats = self.activation(res_feats)
-        new_feats = new_feats + res_feats
-        new_feats = self.dropout(new_feats)
-        new_feats = self.bn_layer_1(new_feats)
+        if self.res_connection is not None:
+            res = self.res_connection
+            new_feats = new_feats + res
+        new_feats = F.dropout(new_feats, p=self.dropout, training=self.training)

         new_feats = self.pos_ff(new_feats)
-        new_feats = self.act(new_feats)

         return new_feats

@@ -121,20 +147,52 @@ def add_args(parser):
         parser.add_argument("--num-layers", type=int, default=2)
         parser.add_argument("--hidden-size", type=int, default=64)
         parser.add_argument("--dropout", type=float, default=0.5)
+        parser.add_argument("--no-residual", action="store_true")
+        parser.add_argument("--norm", type=str, default="batchnorm")
+        parser.add_argument("--activation", type=str, default="relu")
         # fmt: on

     @classmethod
     def build_model_from_args(cls, args):
-        return cls(args.num_features, args.hidden_size, args.num_classes, args.num_layers, args.dropout)
+        return cls(
+            args.num_features,
+            args.hidden_size,
+            args.num_classes,
+            args.num_layers,
+            args.dropout,
+            args.activation,
+            not args.no_residual,
+            args.norm,
+        )

-    def __init__(self, in_feats, hidden_size, out_feats, num_layers, dropout):
+    def __init__(
+        self, in_feats, hidden_size, out_feats, num_layers, dropout, activation="relu", residual=True, norm=None
+    ):
         super(MoEGCN, self).__init__()
         shapes = [in_feats] + [hidden_size] * (num_layers - 1) + [out_feats]
+        conv_func = GraphConv
+        conv_params = {
+            "in_features": in_feats,
+            "out_features": out_feats,
+            "dropout": dropout,
+            "norm": norm,
+            "residual": residual,
+            "activation": activation,
+        }
         self.layers = nn.ModuleList(
-            [GraphConvBlock(shapes[i], shapes[i + 1], activation=F.gelu, dropout=dropout) for i in range(num_layers)]
+            [
+                GraphConvBlock(
+                    conv_func,
+                    conv_params,
+                    shapes[i + 1],
+                    dropout=dropout if i != num_layers - 1 else 0,
+                )
+                for i in range(num_layers)
+            ]
         )
         self.num_layers = num_layers
         self.dropout = dropout
+        self.act = get_activation(activation)

     def get_embeddings(self, graph):
         graph.sym_norm()

@@ -151,7 +209,7 @@ def forward(self, graph):
         for i in range(self.num_layers):
             h = self.layers[i](graph, h)
             if i != self.num_layers - 1:
-                h = F.relu(h)
+                h = self.act(h)
                 h = F.dropout(h, self.dropout, training=self.training)
         return h
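Inside each refactored GraphConvBlock, the data flow is: graph convolution, optional linear residual, dropout, then the MoE position-wise feed-forward. Below is a self-contained sketch of that flow with ordinary modules standing in for GraphConv and CustomizedMoEPositionwiseFF (the real block needs a cogdl Graph and the fmoe package); the class name, the stand-in layers, and the choice to apply the residual projection to the block input are the sketch's own, not the PR's exact code.

import torch
import torch.nn as nn
import torch.nn.functional as F


class BlockSketch(nn.Module):
    """Toy stand-in for GraphConvBlock: conv -> residual -> dropout -> feed-forward."""

    def __init__(self, in_feats, out_feats, dropout=0.5):
        super().__init__()
        self.conv = nn.Linear(in_feats, out_feats)        # stands in for GraphConv
        self.res_connection = nn.Linear(in_feats, out_feats)
        self.pos_ff = nn.Sequential(                      # stands in for the MoE position-wise FF
            nn.Linear(out_feats, out_feats * 2), nn.ReLU(), nn.Linear(out_feats * 2, out_feats)
        )
        self.dropout = dropout

    def forward(self, x):
        h = self.conv(x)
        h = h + self.res_connection(x)                    # residual projection of the block input
        h = F.dropout(h, p=self.dropout, training=self.training)
        return self.pos_ff(h)


print(BlockSketch(16, 8)(torch.randn(5, 16)).shape)      # torch.Size([5, 8])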
cogdl/operators/sample/sample.cpp (2 changes: 1 addition & 1 deletion; the removed and added lines render identically, so the change is presumably whitespace or encoding only)

@@ -1,4 +1,4 @@
-#include <torch/extension.h>
+#include <torch/extension.h>
 #include <pybind11/pybind11.h>
 #include <iostream>
 #include <vector>
cogdl/trainers/m3s_trainer.py (1 change: 0 additions & 1 deletion)

@@ -7,7 +7,6 @@
 from sklearn.cluster import KMeans

 import torch
-import torch.nn.functional as F
 from .base_trainer import BaseTrainer
cogdl/utils/utils.py (2 changes: 1 addition & 1 deletion)

@@ -149,7 +149,7 @@ def get_activation(act: str):
     elif act == "identity":
         return lambda x: x
     else:
-        return F.relu
+        return lambda x: x


 def cycle_index(num, shift):
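With this change, get_activation falls back to the identity function for unrecognized activation names instead of silently substituting ReLU. A minimal sketch of the resulting behavior; the real helper handles more names than shown here.

import torch.nn.functional as F


def get_activation_sketch(act: str):
    # Sketch of the fallback logic only, not the full cogdl helper.
    if act == "relu":
        return F.relu
    elif act == "identity":
        return lambda x: x
    else:
        return lambda x: x   # unknown names now mean "no activation", not ReLU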