-
Notifications
You must be signed in to change notification settings - Fork 11
/
load_data.py
95 lines (83 loc) · 3 KB
/
load_data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import networkx as nx
import numpy as np
import scipy as sc
import os
import re
import util
def _read_int_lines(path):
    """Return the integers stored one per line in the text file at *path*.

    Blank lines (for example a trailing empty line at the end of the file)
    are skipped instead of making ``int('')`` raise.

    Raises:
        IOError: if the file cannot be opened.
        ValueError: if a non-blank line is not an integer literal.
    """
    with open(path) as f:
        return [int(line) for line in f if line.strip()]


def read_graphfile(datadir, dataname, max_nodes=None):
    """Load a multi-graph dataset from a set of plain-text files.

    All files are looked up under ``<datadir>/<dataname>/`` and share the
    prefix ``<dataname>``:

    * ``_graph_indicator.txt``  - line *k* holds the graph id of node *k*
      (node ids are 1-based line numbers).
    * ``_graph_labels.txt``     - one integer class label per graph.
    * ``_A.txt``                - one ``src, dst`` node-id pair per line.
    * ``_node_labels.txt``      - optional, one integer label per node.
    * ``_node_attributes.txt``  - optional, comma/whitespace separated floats
      per node.

    NOTE(review): this layout looks like the TU Dortmund graph-kernel
    benchmark format - confirm against the datasets actually used.

    Args:
        datadir: Directory that contains one sub-directory per dataset.
        dataname: Dataset name (sub-directory name and file prefix).
        max_nodes: Accepted for interface compatibility; this function does
            not currently use it, so no graph is filtered out by size.

    Returns:
        A list of networkx graphs, one per entry in the graph-label file, in
        graph-id order. Each graph carries ``graph['label']`` (graph labels
        remapped to ``0..k-1`` in order of first appearance) and, when node
        attributes exist, ``graph['feat_dim']``. Each node carries a one-hot
        ``'label'`` list when node labels exist and a ``'feat'`` array when
        node attributes exist. Nodes are relabelled to ``0..n-1`` in
        iteration order. Graphs are built from their edge lists only, so a
        node that appears in no edge is not part of the returned graph.

    Raises:
        IOError: if a mandatory file (indicator, graph labels, adjacency) is
            missing.
        ValueError: if a line cannot be parsed as a number.
        KeyError: if an edge refers to a node or graph id that the indicator
            or graph-label file does not define.
    """
    prefix = os.path.join(datadir, dataname, dataname)

    # Node id (1-based line number) -> id of the graph owning that node.
    graph_indic = dict(
        enumerate(_read_int_lines(prefix + '_graph_indicator.txt'), start=1))

    # Optional per-node integer labels, converted to 0-based one-hot indices.
    node_labels = []
    num_unique_node_labels = 0
    try:
        raw_node_labels = _read_int_lines(prefix + '_node_labels.txt')
    except IOError:
        print('No node labels')
    else:
        if raw_node_labels:
            # Labels starting at 1 (or higher) are shifted down by one, as
            # before. When a file uses label 0 (or a negative label) a fixed
            # shift of 1 would yield index -1, which Python silently maps to
            # the LAST one-hot slot and so collides with the largest label;
            # shift by the minimum instead so every label gets its own slot.
            offset = min(1, min(raw_node_labels))
            node_labels = [label - offset for label in raw_node_labels]
            num_unique_node_labels = max(node_labels) + 1

    # Optional per-node float attribute vectors.
    node_attrs = []
    # Raw string: "\s" in a normal literal is an invalid escape sequence.
    separator = re.compile(r"[,\s]+")
    try:
        with open(prefix + '_node_attributes.txt') as f:
            for line in f:
                # One array per line keeps the list aligned with node ids;
                # a blank line therefore yields an empty array.
                tokens = separator.split(line.strip())
                node_attrs.append(
                    np.array([float(tok) for tok in tokens if tok != '']))
    except IOError:
        print('No node attributes')

    # Graph labels are remapped to consecutive ints in first-seen order, so
    # the raw label values need not be contiguous.
    raw_graph_labels = _read_int_lines(prefix + '_graph_labels.txt')
    label_map_to_int = {}
    for val in raw_graph_labels:
        label_map_to_int.setdefault(val, len(label_map_to_int))
    graph_labels = np.array([label_map_to_int[val] for val in raw_graph_labels])

    # Graph id (1-based) -> list of (src, dst) node-id pairs. An edge is
    # assigned to the graph that owns its source node.
    adj_list = {i: [] for i in range(1, len(graph_labels) + 1)}
    with open(prefix + '_A.txt') as f:
        for line in f:
            if not line.strip():
                continue
            fields = line.split(",")
            e0, e1 = int(fields[0]), int(fields[1])
            adj_list[graph_indic[e0]].append((e0, e1))

    graphs = []
    for i in range(1, len(adj_list) + 1):
        G = nx.from_edgelist(adj_list[i])
        G.graph['label'] = graph_labels[i - 1]
        # util.node_iter / util.node_dict are project helpers; presumably
        # thin wrappers over networkx node access - verify in util.
        for u in util.node_iter(G):
            if len(node_labels) > 0:
                node_label_one_hot = [0] * num_unique_node_labels
                # Node ids are 1-based, the label list is 0-based.
                node_label_one_hot[node_labels[u - 1]] = 1
                util.node_dict(G)[u]['label'] = node_label_one_hot
            if len(node_attrs) > 0:
                util.node_dict(G)[u]['feat'] = node_attrs[u - 1]
        if len(node_attrs) > 0:
            G.graph['feat_dim'] = node_attrs[0].shape[0]

        # Replace the dataset-wide node ids with per-graph ids 0..n-1.
        mapping = {n: idx for idx, n in enumerate(util.node_iter(G))}
        graphs.append(nx.relabel_nodes(G, mapping))
    return graphs