Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

update ogbl datasets #358

Merged
merged 9 commits into from
Aug 1, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions cogdl/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,10 @@ def build_dataset_from_path(data_path, dataset=None):
"ogbg-molpcba": "cogdl.datasets.ogb.OGBMolpcbaDataset",
"ogbg-ppa": "cogdl.datasets.ogb.OGBPpaDataset",
"ogbg-code": "cogdl.datasets.ogb.OGBCodeDataset",
"ogbl-ppa": "cogdl.datasets.ogb.OGBLPpaDataset",
"ogbl-ddi": "cogdl.datasets.ogb.OGBLDdiDataset",
"ogbl-collab": "cogdl.datasets.ogb.OGBLCollabDataset",
"ogbl-citation2": "cogdl.datasets.ogb.OGBLCitation2Dataset",
"amazon": "cogdl.datasets.gatne.AmazonDataset",
"twitter": "cogdl.datasets.gatne.TwitterDataset",
"youtube": "cogdl.datasets.gatne.YouTubeDataset",
Expand Down
83 changes: 83 additions & 0 deletions cogdl/datasets/ogb.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from ogb.nodeproppred import NodePropPredDataset
from ogb.nodeproppred import Evaluator as NodeEvaluator
from ogb.graphproppred import GraphPropPredDataset
from ogb.linkproppred import LinkPropPredDataset

from cogdl.data import Dataset, Graph, DataLoader
from cogdl.utils import CrossEntropyLoss, Accuracy, remove_self_loops, coalesce, BCEWithLogitsLoss
Expand Down Expand Up @@ -234,3 +235,85 @@ class OGBCodeDataset(OGBGDataset):
def __init__(self, data_path="data"):
    """Build the ``ogbg-code`` dataset, storing its files under ``data_path``."""
    super(OGBCodeDataset, self).__init__(data_path, "ogbg-code")


# Link property prediction (ogbl-*) datasets

class OGBLDataset(Dataset):
    """Expose an OGB link-property-prediction dataset as a single cogdl ``Graph``.

    The raw graph is loaded through ``LinkPropPredDataset``. Self-loops are
    removed, every remaining edge is mirrored (both ``(u, v)`` and ``(v, u)``
    are kept) and the resulting edge list is passed through ``coalesce``.
    """

    def __init__(self, root, name):
        """
        - name (str): name of the dataset
        - root (str): root directory to store the dataset folder
        """
        self.name = name

        # Keep the OGB dataset on the instance: get_edge_split() reads
        # ``self.dataset``; binding it only to a local made that method fail
        # with AttributeError.
        self.dataset = LinkPropPredDataset(name, root)
        graph = self.dataset[0]

        node_feat = graph["node_feat"]
        x = torch.tensor(node_feat).contiguous() if node_feat is not None else None

        row = torch.from_numpy(graph["edge_index"][0])
        col = torch.from_numpy(graph["edge_index"][1])
        edge_index = torch.stack([row, col], dim=0)

        edge_feat = graph["edge_feat"]
        edge_attr = torch.as_tensor(edge_feat) if edge_feat is not None else None

        edge_index, edge_attr = remove_self_loops(edge_index, edge_attr)

        # Mirror every edge so both directions are present, then coalesce.
        # NOTE(review): edge_attr is neither duplicated nor handed to
        # coalesce(), so when edge features exist they may no longer line up
        # with the mirrored edge_index - confirm before relying on edge_attr.
        row = torch.cat([edge_index[0], edge_index[1]])
        col = torch.cat([edge_index[1], edge_index[0]])
        row, col, _ = coalesce(row, col)
        edge_index = torch.stack([row, col], dim=0)

        self.data = Graph(x=x, edge_index=edge_index, edge_attr=edge_attr, y=None)
        self.data.num_nodes = graph["num_nodes"]

    def get(self, idx):
        """Return the single stored graph; only index 0 is valid."""
        assert idx == 0
        return self.data

    def get_loss_fn(self):
        """Return a ``CrossEntropyLoss`` instance."""
        return CrossEntropyLoss()

    def get_evaluator(self):
        """Return an ``Accuracy`` instance."""
        return Accuracy()

    def _download(self):
        # Intentionally a no-op in this class.
        pass

    @property
    def processed_file_names(self):
        """Name of the processed data file."""
        return "data_cogdl.pt"

    def _process(self):
        # Intentionally a no-op in this class.
        pass

    def get_edge_split(self):
        """Return ``(train_edge, val_edge, test_edge)`` tensors.

        Each tensor is the transpose of the ``'edge'`` array of the
        corresponding split returned by ``self.dataset.get_edge_split()``.
        """
        # NOTE(review): assumes every split dict has an 'edge' entry; verify
        # this for ogbl-citation2, whose split may use different keys.
        idx = self.dataset.get_edge_split()
        train_edge = torch.from_numpy(idx["train"]["edge"].T)
        val_edge = torch.from_numpy(idx["valid"]["edge"].T)
        test_edge = torch.from_numpy(idx["test"]["edge"].T)
        return train_edge, val_edge, test_edge

class OGBLPpaDataset(OGBLDataset):
    """The ``ogbl-ppa`` link prediction dataset."""

    def __init__(self, data_path="data"):
        """Load ``ogbl-ppa``, storing its files under ``data_path``."""
        super().__init__(data_path, "ogbl-ppa")


class OGBLCollabDataset(OGBLDataset):
    """The ``ogbl-collab`` link prediction dataset."""

    def __init__(self, data_path="data"):
        """Load ``ogbl-collab``, storing its files under ``data_path``."""
        super().__init__(data_path, "ogbl-collab")


class OGBLDdiDataset(OGBLDataset):
    """The ``ogbl-ddi`` link prediction dataset."""

    def __init__(self, data_path="data"):
        """Load ``ogbl-ddi``, storing its files under ``data_path``."""
        super().__init__(data_path, "ogbl-ddi")


class OGBLCitation2Dataset(OGBLDataset):
    """The ``ogbl-citation2`` link prediction dataset."""

    def __init__(self, data_path="data"):
        """Load ``ogbl-citation2``, storing its files under ``data_path``."""
        super().__init__(data_path, "ogbl-citation2")


15 changes: 15 additions & 0 deletions tests/datasets/test_ogb.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,22 @@ def test_ogbg_molhiv():
assert dataset.all_nodes == 1049163
assert len(dataset.data) == 41127

def test_ogbl_ddi():
    """Build ``ogbl-ddi`` through the dataset registry and check its node count."""
    dataset_name = "ogbl-ddi"
    args = build_args_from_dict({"dataset": dataset_name})
    assert args.dataset == dataset_name

    graph = build_dataset(args).data
    assert graph.num_nodes == 4267

def test_ogbl_collab():
    """Build ``ogbl-collab`` through the dataset registry and check its node count."""
    dataset_name = "ogbl-collab"
    args = build_args_from_dict({"dataset": dataset_name})
    assert args.dataset == dataset_name

    graph = build_dataset(args).data
    assert graph.num_nodes == 235868

if __name__ == "__main__":
    # Run every dataset test in order when the file is executed as a script.
    for run_test in (test_ogbn_arxiv, test_ogbg_molhiv, test_ogbl_ddi, test_ogbl_collab):
        run_test()