-
Notifications
You must be signed in to change notification settings - Fork 9
/
Copy pathdata_preparation.py
289 lines (225 loc) · 11.2 KB
/
data_preparation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
import torch
import numpy as np
import sys
from collections import defaultdict
from config import FileManager, FileSetter
from assess_performance import PerformanceAssessment
class MyDataset(torch.utils.data.Dataset):
    """Dataset for bindEmbed21DL

    Serves one protein per item as fixed-size, zero-padded arrays so that
    proteins of different lengths can be stacked into one batch.

    :param samples: indexable collection of protein ids; defines dataset order and size
    :param embeddings: mapping protein id -> 2-D embedding; the second axis is
        taken as the number of input features
    :param seqs: mapping protein id -> sequence (only its length is used)
    :param labels: mapping protein id -> per-residue label array that is
        transposed into a (3, length) target
    :param max_length: number of positions every item is padded to
    :param protein_prediction: if True, items additionally carry the protein id
    """

    def __init__(self, samples, embeddings, seqs, labels, max_length, protein_prediction=False):
        self.protein_prediction = protein_prediction
        self.samples = samples
        self.embeddings = embeddings
        self.seqs = seqs
        self.labels = labels
        self.max_length = max_length

        self.n_features = self.get_input_dimensions()
        print('Number of input features: {}'.format(self.n_features))

    def __len__(self):
        """Return the number of proteins in the dataset."""
        return len(self.samples)

    def __getitem__(self, item):
        """
        Build the padded feature, target and loss-mask arrays for one protein
        :param item: position in `samples`
        :return: (features, target, loss_mask), plus the protein id as a fourth
            element when `protein_prediction` is set
        """
        prot_id = self.samples[item]
        prot_length = len(self.seqs[prot_id])
        embedding = self.embeddings[prot_id]

        # pad all inputs to the maximum length & add another feature to encode whether the element is a position
        # in the sequence or padded
        features = np.zeros((self.n_features + 1, self.max_length), dtype=np.float32)
        features[:self.n_features, :prot_length] = np.transpose(embedding)  # set feature maps to embedding values
        features[self.n_features, :prot_length] = 1  # set last element to 1 because positions are not padded

        target = np.zeros((3, self.max_length), dtype=np.float32)
        target[:3, :prot_length] = np.transpose(self.labels[prot_id])

        # Real positions carry the protein length (not 1) as mask value; padded positions stay 0.
        # NOTE(review): presumably consumed downstream to weight/normalise the loss per protein - confirm.
        loss_mask = np.zeros((3, self.max_length), dtype=np.float32)
        loss_mask[:3, :prot_length] = 1 * prot_length

        if self.protein_prediction:
            return features, target, loss_mask, prot_id
        return features, target, loss_mask

    def get_input_dimensions(self):
        """
        Read the number of input features from the first available embedding
        :return: size of the second axis of that embedding
        :raises IndexError: if no embeddings are available
        """
        # Only the first key is needed, so do not materialise the full key list.
        try:
            first_key = next(iter(self.embeddings.keys()))
        except StopIteration:
            raise IndexError(
                'Cannot determine the number of input features: no embeddings were provided'
            ) from None
        return np.shape(self.embeddings[first_key])[1]
class ProteinInformation(object):
    """Static helpers that collect sequences, the padding length and per-residue labels for a set of ids."""

    @staticmethod
    def get_data(ids):
        """
        Get sequences, labels, and maximum length for a set of ids
        :param ids: protein ids to collect data for
        :return: sequences (as read from the FASTA file set in FileSetter), max. length among `ids`,
            labels (protein id -> array with one row per residue and 3 columns)
        """
        sequences = FileManager.read_fasta(FileSetter.fasta_file())
        max_length = ProteinInformation.determine_max_length(sequences, ids)
        labels = ProteinInformation.get_labels(ids, sequences)

        return sequences, max_length, labels

    @staticmethod
    def get_data_predictions(ids, fasta_file):
        """
        Generate dummy labels for test proteins without annotations to allow re-use of general DataLoader
        :param ids: protein ids to generate dummy labels for
        :param fasta_file: FASTA file the sequences are read from
        :return: sequences, max. length, dummy labels
        """
        sequences = FileManager.read_fasta(fasta_file)
        max_length = ProteinInformation.determine_max_length(sequences, ids)

        # all-zero labels: one row per residue, 3 columns
        labels = {
            prot_id: np.zeros([len(sequences[prot_id]), 3], dtype=np.float32)
            for prot_id in ids
        }

        return sequences, max_length, labels

    @staticmethod
    def determine_max_length(sequences, ids):
        """
        Get maximum length in set of sequences
        :param sequences: mapping protein id -> sequence
        :param ids: ids to consider; sequences of other ids are ignored
        :return: length of the longest considered sequence, 0 if `ids` is empty
        """
        return max((len(sequences[prot_id]) for prot_id in ids), default=0)

    @staticmethod
    def get_labels(ids, sequences, file_prefix=None):
        """
        Read binding residues for metal, nucleic acids, and small molecule binding
        :param ids: protein ids to build labels for
        :param sequences: mapping protein id -> sequence (determines the number of rows per protein)
        :param file_prefix: If None, files set in FileSetter will be used
        :return: dict protein id -> float32 array of shape (protein length, 3); column 0 is set to 1 for
            metal, column 1 for nuclear, and column 2 for small binding residues
        """
        # One residue lookup per ligand; the position in this list is the label column.
        residues_by_ligand = []
        for ligand in ('metal', 'nuclear', 'small'):
            if file_prefix is None:
                residue_file = FileSetter.binding_residues_by_ligand(ligand)
            else:
                residue_file = '{}_{}.txt'.format(file_prefix, ligand)
            residues_by_ligand.append(FileManager.read_binding_residues(residue_file))

        labels = dict()
        for prot_id in ids:
            binding_tensor = np.zeros([len(sequences[prot_id]), 3], dtype=np.float32)

            for column, ligand_residues in enumerate(residues_by_ligand):
                # proteins without an entry for this ligand keep an all-zero column
                annotated = ligand_residues.get(prot_id, ())
                zero_based = ProteinInformation._get_zero_based_residues(annotated)
                binding_tensor[zero_based, column] = 1

            labels[prot_id] = binding_tensor

        return labels

    @staticmethod
    def _get_zero_based_residues(residues):
        """
        Convert residue positions to integers and shift them down by one
        :param residues: iterable of values accepted by int()
        :return: list of int positions, each reduced by 1
        """
        # NOTE(review): a position of 0 becomes -1, which numpy indexing treats as the last residue -
        # confirm the residue files are strictly 1-based.
        return [int(r) - 1 for r in residues]
class ProteinResults(object):
    """Labels and predictions of a single protein, plus per-protein performance assessment."""

    # Result keys for label columns 0, 1 and 2 (any further column is also reported as 'small').
    _LIGAND_NAMES = ('metal', 'nucleic', 'small')

    def __init__(self, name, bind_cutoff=0.5):
        """
        :param name: identifier of the protein
        :param bind_cutoff: label values >= this cutoff count as binding
        """
        self.name = name
        self.labels = np.array([])
        self.predictions = np.array([])
        self.bind_cutoff = bind_cutoff
        # cutoff to define if label is binding/non-binding; default: 0: non-binding, 1:binding

    def set_labels(self, labels):
        """Store the labels as a numpy array."""
        self.labels = np.array(labels)

    def set_predictions(self, predictions):
        """Store the transposed predictions, rounded to 3 decimals."""
        self.predictions = np.around(np.array(np.transpose(predictions)), 3)

    def add_predictions(self, predictions):
        """Add predictions (rounded to 3 decimals, NOT transposed) element-wise to the stored ones."""
        self.predictions = np.add(self.predictions, np.around(predictions, 3))

    def normalize_predictions(self, norm_factor):
        """Divide the stored predictions by norm_factor and round to 3 decimals."""
        self.predictions = np.around(self.predictions / norm_factor, 3)

    def calc_num_predictions(self, cutoff):
        """
        Count the residues predicted with a value >= cutoff
        :param cutoff: prediction cutoff
        :return: counts for prediction columns 0, 1 and 2
        """
        num_predictions = np.count_nonzero(self.predictions >= cutoff, axis=0)
        return num_predictions[0], num_predictions[1], num_predictions[2]

    def get_bound_ligand(self, cutoff):
        """
        Check per label column whether at least one residue has a label >= cutoff
        :param cutoff: label cutoff
        :return: three booleans for label columns 0 (metal), 1 (nuclear) and 2 (small)
        """
        num_labels = np.count_nonzero(self.labels >= cutoff, axis=0)
        metal = bool(num_labels[0] > 0)
        nuclear = bool(num_labels[1] > 0)
        small = bool(num_labels[2] > 0)

        return metal, nuclear, small

    def get_predictions_ligand(self, ligand):
        """
        Get the predictions for one ligand type
        :param ligand: 'metal', 'nucleic', 'small', or 'overall' (row-wise maximum over all columns)
        :return: 1-D array with one value per row of the stored predictions
        """
        if ligand == 'overall':
            return np.amax(self.predictions, axis=1)
        if ligand in self._LIGAND_NAMES:
            return self.predictions[:, self._LIGAND_NAMES.index(ligand)]

        # unknown ligand types terminate the program
        sys.exit('{} is not a valid ligand type'.format(ligand))

    def calc_performance_measurements(self, cutoff):
        """Calculate performance values for this protein using its stored labels."""
        performance = self.calc_performance_given_labels(cutoff, self.labels)

        return performance

    def calc_performance_given_labels(self, cutoff, ligand_labels):
        """
        Calculate performance values for this protein
        :param cutoff: predictions >= cutoff count as predicted binding
        :param ligand_labels: 2-D labels (one row per residue, one column per ligand); values
            >= self.bind_cutoff count as annotated binding
        :return: dict with one entry per ligand type (only if there is more than one label column)
            and an 'overall' entry; each entry holds the confusion counts, the derived measures, a
            'bound' flag and the 'cross_prediction' counts
        """
        performance = defaultdict(dict)
        num_ligands = np.shape(ligand_labels)[1]

        if num_ligands > 1:  # ligand-type assessment
            # calc per-ligand assessment for multi-label prediction
            for i in range(num_ligands):
                tp, fp, tn, fn, cross_pred = self._count_ligand_outcomes(cutoff, ligand_labels, i)
                ligand = self._LIGAND_NAMES[min(i, 2)]
                bound = (tp + fn) > 0

                acc, prec, recall, f1, mcc = PerformanceAssessment.calc_performance_measurements(tp, fp, tn, fn)
                # calculate performance measurements for negatives
                _, neg_p, neg_r, neg_f1, _ = PerformanceAssessment.calc_performance_measurements(tn, fn, tp, fp)

                performance[ligand] = {'tp': tp, 'fp': fp, 'tn': tn, 'fn': fn, 'acc': acc, 'prec': prec,
                                       'recall': recall, 'f1': f1, 'neg_prec': neg_p, 'neg_recall': neg_r,
                                       'neg_f1': neg_f1, 'mcc': mcc, 'bound': bound,
                                       'cross_prediction': cross_pred}

        # get overall performance
        tp, fp, tn, fn = self._count_overall_outcomes(cutoff, ligand_labels)

        acc, prec, recall, f1, mcc = PerformanceAssessment.calc_performance_measurements(tp, fp, tn, fn)
        _, neg_p, neg_r, neg_f1, _ = PerformanceAssessment.calc_performance_measurements(tn, fn, tp, fp)

        performance['overall'] = {'tp': tp, 'fp': fp, 'tn': tn, 'fn': fn, 'acc': acc, 'prec': prec, 'recall': recall,
                                  'f1': f1, 'neg_prec': neg_p, 'neg_recall': neg_r, 'neg_f1': neg_f1, 'mcc': mcc,
                                  'bound': True, 'cross_prediction': [0, 0, 0, 0]}

        return performance

    def _count_ligand_outcomes(self, cutoff, ligand_labels, ligand_idx):
        """
        Count TP/FP/TN/FN and cross-predictions for the ligand in column ligand_idx
        :return: tp, fp, tn, fn, cross_pred; cross_pred[k] counts predicted residues annotated to
            ligand k (k == ligand_idx: true positives), cross_pred[3] counts predicted residues
            without any annotation
        """
        num_ligands = np.shape(ligand_labels)[1]
        tp = fp = tn = fn = 0
        cross_pred = [0, 0, 0, 0]

        for res_idx, res_labels in enumerate(ligand_labels):
            # Annotation is always decided by the label cutoff, never by the prediction cutoff.
            annotated = res_labels[ligand_idx] >= self.bind_cutoff

            if self.predictions[res_idx, ligand_idx] < cutoff:  # not predicted as binding to this ligand
                if annotated:
                    fn += 1
                else:
                    tn += 1
                continue

            if annotated:
                # A true positive is never a cross-prediction, regardless of the ligand order.
                tp += 1
                cross_pred[ligand_idx] += 1
                continue

            fp += 1
            bound_elsewhere = False
            for other in range(num_ligands):
                if other != ligand_idx and res_labels[other] >= self.bind_cutoff:
                    cross_pred[other] += 1
                    bound_elsewhere = True
            if not bound_elsewhere:
                # residue is not annotated to bind any of the ligands
                cross_pred[3] += 1

        return tp, fp, tn, fn, cross_pred

    def _count_overall_outcomes(self, cutoff, ligand_labels):
        """
        Count TP/FP/TN/FN for binding to any ligand vs. not binding at all
        :return: tp, fp, tn, fn
        """
        # A residue is annotated as binding if any label reaches the label cutoff.
        reduced_labels = np.sum(np.asarray(ligand_labels) >= self.bind_cutoff, axis=1)
        if len(self.predictions.shape) == 1:
            reduced_predictions = (self.predictions >= cutoff)
        else:
            reduced_predictions = np.sum(self.predictions >= cutoff, axis=1)

        tp = np.sum(np.logical_and(reduced_labels > 0, reduced_predictions > 0))
        fp = np.sum(np.logical_and(reduced_labels == 0, reduced_predictions > 0))
        tn = np.sum(np.logical_and(reduced_labels == 0, reduced_predictions == 0))
        fn = np.sum(np.logical_and(reduced_labels > 0, reduced_predictions == 0))

        return tp, fp, tn, fn