diff --git a/cogdl/datasets/__init__.py b/cogdl/datasets/__init__.py
index 2bf9615e..813c8a40 100644
--- a/cogdl/datasets/__init__.py
+++ b/cogdl/datasets/__init__.py
@@ -109,6 +109,10 @@ def build_dataset_from_path(data_path, dataset=None):
     "ogbg-molpcba": "cogdl.datasets.ogb.OGBMolpcbaDataset",
     "ogbg-ppa": "cogdl.datasets.ogb.OGBPpaDataset",
     "ogbg-code": "cogdl.datasets.ogb.OGBCodeDataset",
+    "ogbl-ppa": "cogdl.datasets.ogb.OGBLPpaDataset",
+    "ogbl-ddi": "cogdl.datasets.ogb.OGBLDdiDataset",
+    "ogbl-collab": "cogdl.datasets.ogb.OGBLCollabDataset",
+    "ogbl-citation2": "cogdl.datasets.ogb.OGBLCitation2Dataset",
     "amazon": "cogdl.datasets.gatne.AmazonDataset",
     "twitter": "cogdl.datasets.gatne.TwitterDataset",
     "youtube": "cogdl.datasets.gatne.YouTubeDataset",
diff --git a/cogdl/datasets/ogb.py b/cogdl/datasets/ogb.py
index a5b41c73..1df6e55a 100644
--- a/cogdl/datasets/ogb.py
+++ b/cogdl/datasets/ogb.py
@@ -4,6 +4,7 @@
 from ogb.nodeproppred import NodePropPredDataset
 from ogb.nodeproppred import Evaluator as NodeEvaluator
 from ogb.graphproppred import GraphPropPredDataset
+from ogb.linkproppred import LinkPropPredDataset
 
 from cogdl.data import Dataset, Graph, DataLoader
 from cogdl.utils import CrossEntropyLoss, Accuracy, remove_self_loops, coalesce, BCEWithLogitsLoss
@@ -234,3 +235,100 @@ class OGBCodeDataset(OGBGDataset):
     def __init__(self, data_path="data"):
         dataset = "ogbg-code"
         super(OGBCodeDataset, self).__init__(data_path, dataset)
+
+
+# OGB link property prediction ("ogbl-*") datasets.
+
+
+class OGBLDataset(Dataset):
+    """Wrap an OGB ``LinkPropPredDataset`` as a single-graph cogdl dataset.
+
+    The graph is stored in ``self.data``; the underlying OGB dataset is kept
+    in ``self.dataset`` so that ``get_edge_split`` can query it later.
+    """
+
+    def __init__(self, root, name):
+        """
+        - name (str): name of the dataset
+        - root (str): root directory to store the dataset folder
+        """
+
+        self.name = name
+
+        # Keep a handle on the OGB dataset: get_edge_split() reads it.
+        self.dataset = LinkPropPredDataset(name, root)
+        graph = self.dataset[0]
+        x = torch.tensor(graph["node_feat"]).contiguous() if graph["node_feat"] is not None else None
+        row, col = graph["edge_index"][0], graph["edge_index"][1]
+        row = torch.from_numpy(row)
+        col = torch.from_numpy(col)
+        edge_index = torch.stack([row, col], dim=0)
+        edge_attr = torch.as_tensor(graph["edge_feat"]) if graph["edge_feat"] is not None else graph["edge_feat"]
+        edge_index, edge_attr = remove_self_loops(edge_index, edge_attr)
+        # Append the reverse of every edge so both directions are present.
+        # NOTE(review): edge_attr is not expanded alongside edge_index here,
+        # so the two may not line up when edge features exist - confirm.
+        row = torch.cat([edge_index[0], edge_index[1]])
+        col = torch.cat([edge_index[1], edge_index[0]])
+
+        row, col, _ = coalesce(row, col)
+        edge_index = torch.stack([row, col], dim=0)
+
+        self.data = Graph(x=x, edge_index=edge_index, edge_attr=edge_attr, y=None)
+        self.data.num_nodes = graph["num_nodes"]
+
+    def get(self, idx):
+        """Return the only graph; ``idx`` must be 0."""
+        assert idx == 0
+        return self.data
+
+    def get_loss_fn(self):
+        return CrossEntropyLoss()
+
+    def get_evaluator(self):
+        return Accuracy()
+
+    def _download(self):
+        pass
+
+    @property
+    def processed_file_names(self):
+        return "data_cogdl.pt"
+
+    def _process(self):
+        pass
+
+    def get_edge_split(self):
+        """Return the train/valid/test edge tensors from the OGB edge split.
+
+        Each tensor is the transposed ``"edge"`` array of its split.
+        NOTE(review): assumes every split exposes an ``"edge"`` key; verify
+        this holds for ogbl-citation2.
+        """
+        idx = self.dataset.get_edge_split()
+        train_edge = torch.from_numpy(idx['train']['edge'].T)
+        val_edge = torch.from_numpy(idx['valid']['edge'].T)
+        test_edge = torch.from_numpy(idx['test']['edge'].T)
+        return train_edge, val_edge, test_edge
+
+
+class OGBLPpaDataset(OGBLDataset):
+    def __init__(self, data_path="data"):
+        dataset = "ogbl-ppa"
+        super(OGBLPpaDataset, self).__init__(data_path, dataset)
+
+
+class OGBLCollabDataset(OGBLDataset):
+    def __init__(self, data_path="data"):
+        dataset = "ogbl-collab"
+        super(OGBLCollabDataset, self).__init__(data_path, dataset)
+
+
+class OGBLDdiDataset(OGBLDataset):
+    def __init__(self, data_path="data"):
+        dataset = "ogbl-ddi"
+        super(OGBLDdiDataset, self).__init__(data_path, dataset)
+
+
+class OGBLCitation2Dataset(OGBLDataset):
+    def __init__(self, data_path="data"):
+        dataset = "ogbl-citation2"
+        super(OGBLCitation2Dataset, self).__init__(data_path, dataset)
diff --git a/tests/datasets/test_ogb.py b/tests/datasets/test_ogb.py
index 64cb10fc..d491244a 100644
--- a/tests/datasets/test_ogb.py
+++ b/tests/datasets/test_ogb.py
@@ -18,7 +18,25 @@ def test_ogbg_molhiv():
     assert dataset.all_nodes == 1049163
     assert len(dataset.data) == 41127
 
 
+def test_ogbl_ddi():
+    args = build_args_from_dict({"dataset": "ogbl-ddi"})
+    assert args.dataset == "ogbl-ddi"
+    dataset = build_dataset(args)
+    data = dataset.data
+    assert data.num_nodes == 4267
+
+
+def test_ogbl_collab():
+    args = build_args_from_dict({"dataset": "ogbl-collab"})
+    assert args.dataset == "ogbl-collab"
+    dataset = build_dataset(args)
+    data = dataset.data
+    assert data.num_nodes == 235868
+
+
 if __name__ == "__main__":
     test_ogbn_arxiv()
     test_ogbg_molhiv()
+    test_ogbl_ddi()
+    test_ogbl_collab()