forked from hyp1231/GMPT
-
Notifications
You must be signed in to change notification settings - Fork 3
/
bio_loader.py
55 lines (47 loc) · 2.19 KB
/
bio_loader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import torch
from torch_geometric.data import InMemoryDataset
class BioDataset(InMemoryDataset):
    """In-memory PyG dataset of pre-processed biological graphs.

    Adapted from qm9.py with downloading and processing disabled: the
    collated data must already have been written to
    ``geometric_data_processed.pt`` in the processed dir under ``root``.
    """

    # Accepted values for the ``data_type`` constructor argument.
    DATA_TYPES = ('supervised', 'unsupervised')

    # Species identifiers (presumably NCBI taxonomy IDs -- confirm) of the
    # 8 labelled species; each names one expected raw file.
    SUPERVISED_SPECIES = ('3702', '6239', '511145', '7227', '9606', '10090',
                          '4932', '7955')

    # The 8 labelled species followed by the 42 top unlabelled species by
    # n_nodes (50 in total).
    UNSUPERVISED_SPECIES = SUPERVISED_SPECIES + (
        '3694', '39947', '10116', '443255', '9913', '13616',
        '3847', '4577', '8364', '9823', '9615', '9544', '9796', '3055', '7159',
        '9031', '7739', '395019', '88036', '9685', '9258', '9598', '485913',
        '44689', '9593', '7897', '31033', '749414', '59729', '536227', '4081',
        '8090', '9601', '749927', '13735', '448385', '457427', '3711', '479433',
        '479432', '28377', '9646',
    )

    def __init__(self,
                 root,
                 data_type,
                 empty=False,
                 transform=None,
                 pre_transform=None,
                 pre_filter=None):
        """
        Adapted from qm9.py. Disabled the download functionality
        :param root: the data directory that contains a raw and processed dir
        :param data_type: either 'supervised' or 'unsupervised'; selects
            which species' raw files make up the dataset
        :param empty: if True, then will not load any data obj. For
            initializing empty dataset
        :param transform: forwarded unchanged to InMemoryDataset
        :param pre_transform: forwarded unchanged to InMemoryDataset
        :param pre_filter: forwarded unchanged to InMemoryDataset
        :raises ValueError: if data_type is not one of DATA_TYPES
        """
        # Fail fast on typos instead of silently falling back to the
        # unsupervised species list.
        if data_type not in self.DATA_TYPES:
            raise ValueError(
                f"data_type must be one of {self.DATA_TYPES}, "
                f"got {data_type!r}")
        self.root = root
        # Set before the base constructor runs: it may consult
        # raw_file_names, which reads self.data_type.
        self.data_type = data_type
        super().__init__(root, transform, pre_transform, pre_filter)
        if not empty:
            self.data, self.slices = self._load_processed(
                self.processed_paths[0])

    @staticmethod
    def _load_processed(path):
        """Load and return the object saved at ``path`` with ``torch.load``.

        torch>=2.6 defaults ``torch.load`` to ``weights_only=True``, which
        refuses to unpickle torch_geometric data objects, so full unpickling
        is requested explicitly. Torch versions that predate the keyword are
        called without it.
        """
        import inspect

        # NOTE(security): weights_only=False executes arbitrary pickle code;
        # only point ``root`` at processed files you trust.
        if 'weights_only' in inspect.signature(torch.load).parameters:
            return torch.load(path, weights_only=False)
        return torch.load(path)

    @property
    def raw_file_names(self):
        """Raw file names (one per species) expected in the raw dir."""
        if self.data_type == 'supervised':  # 8 labelled species
            return list(self.SUPERVISED_SPECIES)
        # unsupervised: 8 labelled species + 42 top unlabelled species
        return list(self.UNSUPERVISED_SPECIES)

    @property
    def processed_file_names(self):
        """Name of the single pre-collated file in the processed dir."""
        return 'geometric_data_processed.pt'

    def download(self):
        """Always raises: raw data must be supplied by the user."""
        raise NotImplementedError('Must indicate valid location of raw data. '
                                  'No download allowed')

    def process(self):
        """Always raises: processed data must already exist on disk."""
        raise NotImplementedError('Data is assumed to be processed')