Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add reorder_like to speed up the averaging of edge weights for undirected graphs #8

Merged
merged 1 commit into from
Nov 18, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions example/gsat.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@
import scipy
import torch
import torch.nn as nn
from torch_sparse import transpose
from torch_geometric.utils import is_undirected
from utils import MLP
from utils import MLP, reorder_like


class GSAT(nn.Module):
Expand Down Expand Up @@ -40,9 +41,9 @@ def forward_pass(self, data, epoch, training):

if self.learn_edge_att:
if is_undirected(data.edge_index):
nodesize = data.x.shape[0]
sci_csr = scipy.sparse.csr_matrix((torch.arange(att.shape[0]), (data.edge_index[0].cpu(), data.edge_index[1].cpu())), (nodesize, nodesize))
edge_att = (att + att[sci_csr[data.edge_index[1].tolist(), data.edge_index[0].tolist()].A1]) / 2
trans_idx, trans_val = transpose(data.edge_index, att, None, None, coalesced=False)
trans_val_perm = reorder_like(trans_idx, data.edge_index, trans_val)
edge_att = (att + trans_val_perm) / 2
else:
edge_att = att
else:
Expand Down
11 changes: 6 additions & 5 deletions src/run_gsat.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,16 @@
import torch
import torch.nn as nn
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch_geometric.utils import subgraph, is_undirected
from torch_sparse import transpose
from torch_geometric.loader import DataLoader
from torch_geometric.utils import subgraph, is_undirected
from ogb.graphproppred import Evaluator
from sklearn.metrics import roc_auc_score
from rdkit import Chem

from pretrain_clf import train_clf_one_seed
from utils import Writer, Criterion, MLP, visualize_a_graph, save_checkpoint, load_checkpoint, get_preds, get_lr, set_seed, process_data
from utils import get_local_config_name, get_model, get_data_loaders, write_stat_from_metric_dicts, init_metric_dict
from utils import get_local_config_name, get_model, get_data_loaders, write_stat_from_metric_dicts, reorder_like, init_metric_dict


class GSAT(nn.Module):
Expand Down Expand Up @@ -75,9 +76,9 @@ def forward_pass(self, data, epoch, training):

if self.learn_edge_att:
if is_undirected(data.edge_index):
nodesize = data.x.shape[0]
sp_csr = scipy.sparse.csr_matrix((torch.arange(att.shape[0]), (data.edge_index[0].cpu(), data.edge_index[1].cpu())), (nodesize, nodesize))
edge_att = (att + att[sp_csr[data.edge_index[1].tolist(), data.edge_index[0].tolist()].A1]) / 2
trans_idx, trans_val = transpose(data.edge_index, att, None, None, coalesced=False)
trans_val_perm = reorder_like(trans_idx, data.edge_index, trans_val)
edge_att = (att + trans_val_perm) / 2
else:
edge_att = att
else:
Expand Down
11 changes: 10 additions & 1 deletion src/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from rdkit import Chem
import matplotlib.pyplot as plt
from torch_geometric.data import Data
from torch_geometric.utils import to_networkx
from torch_geometric.utils import to_networkx, sort_edge_index
from torch.utils.tensorboard import SummaryWriter
from torch.utils.tensorboard.summary import hparams

Expand All @@ -16,6 +16,15 @@
'metric/best_x_precision_train': 0, 'metric/best_x_precision_valid': 0, 'metric/best_x_precision_test': 0}


def reorder_like(from_edge_index, to_edge_index, values):
from_edge_index, values = sort_edge_index(from_edge_index, values)
ranking_score = to_edge_index[0] * (to_edge_index.max()+1) + to_edge_index[1]
ranking = ranking_score.argsort().argsort()
if not (from_edge_index[:, ranking] == to_edge_index).all():
raise ValueError("Edges in from_edge_index and to_edge_index are different, impossible to match both.")
return values[ranking]


def process_data(data, use_edge_attr):
if not use_edge_attr:
data.edge_attr = None
Expand Down