Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature]Remove dependence on pyg #169

Merged
merged 1 commit into from
Jan 21, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions cogdl/data/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,8 @@ def num_edges(self):
@property
def num_features(self):
    r"""Returns the number of features per node in the graph.

    A missing feature matrix counts as 0 features; a one-dimensional
    feature tensor counts as a single scalar feature per node.
    """
    if self.x is None:
        return 0
    if self.x.dim() == 1:
        return 1
    return self.x.size(1)

@property
Expand Down
261 changes: 260 additions & 1 deletion cogdl/datasets/pyg.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,18 @@
import glob
import os
import os.path as osp
import shutil
import zipfile
from itertools import repeat

import numpy as np
import torch
import torch.nn.functional as F

from torch_geometric.datasets import Planetoid, Reddit, TUDataset, QM9
from cogdl.data.dataset import Dataset
from . import register_dataset
from ..data import Data
from ..utils import download_url


def normalize_feature(data):
Expand All @@ -14,6 +24,255 @@ def normalize_feature(data):
return data


def parse_txt_array(src, sep=None, start=0, end=None, dtype=None, device=None):
    """Parse an iterable of text lines into a tensor.

    Each line is split on ``sep`` (any whitespace when ``None``), the
    fields ``[start:end]`` are converted to ``float``, and the resulting
    rows are stacked into a tensor which is then squeezed.

    Args:
        src (iterable of str): lines to parse.
        sep (str, optional): field separator passed to ``str.split``.
        start (int): index of the first field kept on each line.
        end (int, optional): one-past-last field index kept on each line.
        dtype (torch.dtype, optional): dtype of the returned tensor.
        device (torch.device, optional): device of the returned tensor.

    Returns:
        torch.Tensor: the parsed, squeezed tensor.
    """
    rows = [[float(field) for field in line.split(sep)[start:end]] for line in src]
    # Fix: honor `device`, which was previously accepted but silently ignored.
    return torch.tensor(rows, dtype=dtype, device=device).squeeze()


def read_txt_array(path, sep=None, start=0, end=None, dtype=None, device=None):
    """Read a delimited text file into a tensor via :func:`parse_txt_array`.

    The trailing empty chunk produced by the final newline is dropped
    before the lines are parsed.
    """
    with open(path, 'r') as handle:
        lines = handle.read().split('\n')[:-1]
    return parse_txt_array(lines, sep, start, end, dtype, device)


def read_file(folder, prefix, name, dtype=None):
    """Load ``<folder>/<prefix>_<name>.txt`` as a comma-separated tensor."""
    filename = '{}_{}.txt'.format(prefix, name)
    return read_txt_array(osp.join(folder, filename), sep=',', dtype=dtype)


def cat(seq):
    """Concatenate tensors along the last dimension, skipping ``None``s.

    One-dimensional tensors are promoted to column vectors first.
    Returns ``None`` when every entry of ``seq`` is ``None``.
    """
    tensors = []
    for item in seq:
        if item is None:
            continue
        tensors.append(item.unsqueeze(-1) if item.dim() == 1 else item)
    if not tensors:
        return None
    return torch.cat(tensors, dim=-1)


def split(data, batch):
    """Build per-graph slice boundaries for a concatenated batch of graphs.

    ``batch`` maps every node to the index of the graph it belongs to.
    Edge indices are shifted so each graph's node ids restart at zero,
    and a ``slices`` dict of cumulative boundaries is built for every
    attribute present on ``data``.

    Returns:
        (data, slices): the modified data object and the boundary dict.
    """
    # node_bounds[g] .. node_bounds[g + 1] spans graph g's nodes.
    node_bounds = torch.cat([
        torch.tensor([0]),
        torch.cumsum(torch.from_numpy(np.bincount(batch)), 0),
    ])

    src_nodes, _ = data.edge_index
    edge_bounds = torch.cat([
        torch.tensor([0]),
        torch.cumsum(torch.from_numpy(np.bincount(batch[src_nodes])), 0),
    ])

    # Shift edge endpoints so every graph's node ids start at zero.
    data.edge_index -= node_bounds[batch[src_nodes]].unsqueeze(0)
    data.__num_nodes__ = torch.bincount(batch).tolist()

    slices = {'edge_index': edge_bounds}
    if data.x is not None:
        slices['x'] = node_bounds
    if data.edge_attr is not None:
        slices['edge_attr'] = edge_bounds
    if data.y is not None:
        if data.y.size(0) == batch.size(0):
            # One target per node.
            slices['y'] = node_bounds
        else:
            # One target per graph.
            slices['y'] = torch.arange(0, batch[-1] + 2, dtype=torch.long)

    return data, slices


def segment(src, indptr):
    """Sum ``src`` over consecutive segments delimited by ``indptr``.

    ``indptr`` holds ``k + 1`` increasing offsets; segment ``i`` covers
    ``indptr[..., i]:indptr[..., i + 1]`` of ``src`` along dimension
    ``indptr.dim() - 1``.  Returns the ``k`` per-segment sums concatenated
    along that same dimension (each kept with ``keepdim=True``).
    """
    axis = indptr.dim() - 1
    pieces = []
    for i in range(indptr.size(-1) - 1):
        lo = indptr[..., i].item()
        hi = indptr[..., i + 1].item()
        selector = torch.arange(lo, hi, dtype=torch.int64)
        chunk = src.index_select(axis, selector)
        pieces.append(chunk.sum(dim=axis, keepdim=True))
    return torch.cat(pieces, dim=axis)


def coalesce(index, value, m, n):
    """Sort edges lexicographically and merge duplicates by summation.

    Args:
        index: ``[2, E]`` edge tensor holding (row, col) pairs.
        value: optional per-edge values; values of duplicate edges are
            summed segment-wise via :func:`segment`.
        m: number of rows of the underlying adjacency matrix (unused in
            the key, kept for interface compatibility).
        n: number of columns; participates in the linearized sort key.

    Returns:
        (index, value): the coalesced edge index and values.
    """
    row, col = index[0], index[1]

    # Linearize (row, col) into one sort key: key = row * n + col.
    # keys[0] stays 0 as a sentinel so keys[1:] aligns with the edges.
    keys = col.new_zeros(col.numel() + 1)
    keys[1:] = row
    keys[1:] *= n
    keys[1:] += col
    if (keys[1:] < keys[:-1]).any():
        order = keys[1:].argsort()
        row, col = row[order], col[order]
        if value is not None:
            value = value[order]

    keys = col.new_full((col.numel() + 1,), -1)
    keys[1:] = n * row + col
    # An edge survives when its key differs from its predecessor's.
    keep = keys[1:] > keys[:-1]

    if keep.all():  # Already sorted and unique — nothing to merge.
        return torch.stack([row, col], dim=0), value

    row, col = row[keep], col[keep]

    if value is not None:
        ptr = keep.nonzero().flatten()
        ptr = torch.cat([ptr, ptr.new_full((1,), value.size(0))])
        value = segment(value, ptr)
        value = value[0] if isinstance(value, tuple) else value

    return torch.stack([row, col], dim=0), value


def read_tu_data(folder, prefix):
    """Load a TU-format graph-kernel dataset from ``folder``.

    Reads the ``<prefix>_*.txt`` files (adjacency, graph indicator and any
    optional node/edge/graph attribute or label files), one-hot encodes
    label columns, removes self-loops, coalesces duplicate edges and
    splits the concatenated graph into per-graph slices.

    Returns:
        (data, slices): the concatenated graph plus the slice dict
        produced by :func:`split`.
    """
    paths = glob.glob(osp.join(folder, '{}_*.txt'.format(prefix)))
    names = [p.split(os.sep)[-1][len(prefix) + 1:-4] for p in paths]

    # Files on disk are 1-indexed; shift node and graph ids to 0-based.
    edge_index = read_file(folder, prefix, 'A', torch.long).t() - 1
    batch = read_file(folder, prefix, 'graph_indicator', torch.long) - 1

    node_attributes = node_labels = None
    if 'node_attributes' in names:
        node_attributes = read_file(folder, prefix, 'node_attributes')
    if 'node_labels' in names:
        node_labels = read_file(folder, prefix, 'node_labels', torch.long)
        if node_labels.dim() == 1:
            node_labels = node_labels.unsqueeze(-1)
        # Rebase every label column to start at 0, then one-hot encode.
        node_labels = node_labels - node_labels.min(dim=0)[0]
        encoded = [F.one_hot(c, num_classes=-1) for c in node_labels.unbind(dim=-1)]
        node_labels = torch.cat(encoded, dim=-1).to(torch.float)
    x = cat([node_attributes, node_labels])

    edge_attributes = edge_labels = None
    if 'edge_attributes' in names:
        edge_attributes = read_file(folder, prefix, 'edge_attributes')
    if 'edge_labels' in names:
        edge_labels = read_file(folder, prefix, 'edge_labels', torch.long)
        if edge_labels.dim() == 1:
            edge_labels = edge_labels.unsqueeze(-1)
        edge_labels = edge_labels - edge_labels.min(dim=0)[0]
        encoded = [F.one_hot(c, num_classes=-1) for c in edge_labels.unbind(dim=-1)]
        edge_labels = torch.cat(encoded, dim=-1).to(torch.float)
    edge_attr = cat([edge_attributes, edge_labels])

    y = None
    if 'graph_attributes' in names:  # Regression targets.
        y = read_file(folder, prefix, 'graph_attributes')
    elif 'graph_labels' in names:  # Classification targets: remap to 0..C-1.
        y = read_file(folder, prefix, 'graph_labels', torch.long)
        _, y = y.unique(sorted=True, return_inverse=True)

    num_nodes = edge_index.max().item() + 1 if x is None else x.size(0)

    # Drop self-loops, then merge duplicate edges.
    keep = edge_index[0] != edge_index[1]
    edge_index = edge_index[:, keep]
    if edge_attr is not None:
        edge_attr = edge_attr[keep]

    edge_index, edge_attr = coalesce(edge_index, edge_attr, num_nodes, num_nodes)

    data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y)
    return split(data, batch)


class TUDataset(Dataset):
    """Dataset wrapper for the TU Dortmund graph-kernel collection.

    Downloads ``<name>.zip`` from the TU benchmark site, parses the raw
    text files with :func:`read_tu_data` and caches the result as a
    single processed ``data.pt`` file.
    """

    url = 'https://www.chrsmrrs.com/graphkerneldatasets'

    def __init__(self, root, name):
        self.name = name
        super(TUDataset, self).__init__(root)
        self.data, self.slices = torch.load(self.processed_paths[0])
        # Keep only the trailing one-hot label columns; the leading
        # continuous attribute columns are stripped off.
        if self.data.x is not None:
            self.data.x = self.data.x[:, self.num_node_attributes:]
        if self.data.edge_attr is not None:
            self.data.edge_attr = self.data.edge_attr[:, self.num_edge_attributes:]

    @property
    def raw_file_names(self):
        """The two files every TU dataset must provide."""
        return ['{}_{}.txt'.format(self.name, suffix) for suffix in ('A', 'graph_indicator')]

    @property
    def processed_file_names(self):
        return 'data.pt'

    def download(self):
        """Fetch ``<name>.zip`` and unpack it into ``self.raw_dir``."""
        folder = osp.join(self.root)
        archive = download_url('{}/{}.zip'.format(self.url, self.name), folder)
        with zipfile.ZipFile(archive, 'r') as zf:
            zf.extractall(folder)
        os.unlink(archive)
        shutil.rmtree(self.raw_dir)
        os.rename(osp.join(folder, self.name), self.raw_dir)

    def process(self):
        """Parse the raw files and cache the (data, slices) pair."""
        self.data = read_tu_data(self.raw_dir, self.name)
        torch.save(self.data, self.processed_paths[0])

    @property
    def num_node_labels(self):
        """Width of the trailing block of ``x`` that is a valid one-hot code."""
        if self.data.x is None:
            return 0
        # Scan suffixes from the widest down: the first suffix whose
        # entries are all 0/1 and whose rows each sum to 1 is the label block.
        for start_col in range(self.data.x.size(1)):
            suffix = self.data.x[:, start_col:]
            if ((suffix == 0) | (suffix == 1)).all() and (suffix.sum(dim=1) == 1).all():
                return self.data.x.size(1) - start_col
        return 0

    @property
    def num_node_attributes(self):
        """Number of leading non-label columns in ``x``."""
        if self.data.x is None:
            return 0
        return self.data.x.size(1) - self.num_node_labels

    @property
    def num_edge_labels(self):
        """Width of the trailing one-hot block of ``edge_attr``."""
        if self.data.edge_attr is None:
            return 0
        for start_col in range(self.data.edge_attr.size(1)):
            # A one-hot suffix sums to exactly one entry per edge.
            if self.data.edge_attr[:, start_col:].sum() == self.data.edge_attr.size(0):
                return self.data.edge_attr.size(1) - start_col
        return 0

    @property
    def num_edge_attributes(self):
        """Number of leading non-label columns in ``edge_attr``."""
        if self.data.edge_attr is None:
            return 0
        return self.data.edge_attr.size(1) - self.num_edge_labels

    @property
    def num_classes(self):
        r"""The number of classes in the dataset."""
        y = self.data.y
        if y.dim() == 1:
            return y.max().item() + 1
        return y.size(1)

    def __len__(self):
        # Every slice tensor has one more boundary than there are graphs.
        for bounds in self.slices.values():
            return len(bounds) - 1
        return 0

    def get(self, idx):
        """Materialize graph ``idx`` by slicing every stored attribute."""
        data = self.data.__class__()
        if hasattr(self.data, '__num_nodes__'):
            data.num_nodes = self.data.__num_nodes__[idx]

        for key in self.data.keys:
            item, bounds = self.data[key], self.slices[key]
            start, end = bounds[idx].item(), bounds[idx + 1].item()
            if torch.is_tensor(item):
                # Slice along the attribute's concatenation dimension only.
                selector = [slice(None)] * item.dim()
                selector[self.data.cat_dim(key, item)] = slice(start, end)
            elif start + 1 == end:
                selector = bounds[start]
            else:
                selector = slice(start, end)
            data[key] = item[selector]

        return data


@register_dataset("mutag")
class MUTAGDataset(TUDataset):
def __init__(self, args=None):
Expand Down
4 changes: 2 additions & 2 deletions cogdl/tasks/graph_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,8 +223,8 @@ def generate_data(self, dataset, args):
return {"train": train_set, "test": test_set}
else:
datalist = []
if isinstance(dataset[0], Data):
return dataset
# if isinstance(dataset[0], Data):
# return dataset
for idata in dataset:
data = Data()
for key in idata.keys:
Expand Down