-
Notifications
You must be signed in to change notification settings - Fork 11
/
PlanarSATPairsDataset.py
42 lines (31 loc) · 1.23 KB
/
PlanarSATPairsDataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import torch
from torch_geometric.data import InMemoryDataset
import pickle
import os
from torch_geometric.utils import to_networkx
# Base name of the raw pickle file ("<NAME>.pkl") that the dataset reads.
NAME: str = "GRAPHSAT"
class PlanarSATPairsDataset(InMemoryDataset):
    """In-memory PyG dataset built from a single pickled list of graphs.

    The raw input is one file named ``NAME + ".pkl"`` in the dataset's
    ``raw`` directory under ``root``. It must unpickle to a list of data
    objects that ``InMemoryDataset.collate`` can batch. ``process()``
    optionally filters and pre-transforms that list, collates it, and saves
    the result as ``data.pt`` (``processed_paths[0]``). The constructor then
    loads that file.

    Args:
        root: Dataset root directory.
        transform: Optional transform, passed through to ``InMemoryDataset``.
        pre_transform: Optional callable applied to every item in
            ``process()`` before collation.
        pre_filter: Optional predicate applied in ``process()``; items for
            which it returns a falsy value are dropped before collation.
    """

    def __init__(self, root, transform=None, pre_transform=None, pre_filter=None):
        super().__init__(root, transform, pre_transform, pre_filter)
        # presumably the base constructor has already run process() when the
        # processed file was missing, so processed_paths[0] exists here.
        # NOTE(review): torch >= 2.6 defaults torch.load to weights_only=True,
        # which may reject the pickled (data, slices) tuple; confirm against
        # the pinned torch / torch_geometric versions.
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_file_names(self):
        """Return the single raw file name expected in the raw directory."""
        return [NAME + ".pkl"]

    @property
    def processed_file_names(self):
        """Return the name of the cached, collated dataset file."""
        return 'data.pt'

    def download(self):
        """No-op: the raw pickle must be placed in the raw directory manually."""
        pass

    def process(self):
        """Load the raw pickle, apply pre_filter/pre_transform, collate, save.

        Raises:
            FileNotFoundError: if the raw pickle file is missing.
        """
        # Read data into a (potentially large) list of data objects.
        # Use raw_paths so the location always agrees with raw_file_names,
        # and a context manager so the file handle is closed promptly.
        # SECURITY: pickle.load can execute arbitrary code during
        # unpickling; only load raw files from a trusted source.
        with open(self.raw_paths[0], "rb") as raw_file:
            data_list = pickle.load(raw_file)

        if self.pre_filter is not None:
            data_list = [data for data in data_list if self.pre_filter(data)]

        if self.pre_transform is not None:
            data_list = [self.pre_transform(data) for data in data_list]

        data, slices = self.collate(data_list)
        torch.save((data, slices), self.processed_paths[0])
if __name__ == "__main__":
    # Manual smoke check: build (or reload the cached) dataset under the
    # example root and print its first graph.
    example_root = "Data/EXP/"
    example_dataset = PlanarSATPairsDataset(example_root)
    first_graph = example_dataset[0]
    print(first_graph)