forked from pranayj77/medfusion
-
Notifications
You must be signed in to change notification settings - Fork 0
/
dataset_simple_2d.py
executable file
·198 lines (166 loc) · 8.34 KB
/
dataset_simple_2d.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
import torch.utils.data as data
import torch
from torch import nn
from pathlib import Path
from torchvision import transforms as T
import pandas as pd
from PIL import Image
from medical_diffusion.data.augmentation.augmentations_2d import Normalize, ToTensor16bit
class SimpleDataset2D(data.Dataset):
    """Generic 2D image dataset.

    Crawls `path_root` recursively for files with extension `crawler_ext`
    (unless explicit relative paths are given via `item_pointers`) and yields
    dicts of the form ``{'uid': <file stem>, 'source': <transformed image>}``.

    Args:
        path_root: Root directory containing the images.
        item_pointers: Optional list of paths relative to `path_root`. If
            None/empty, the directory is crawled instead.
        crawler_ext: File extension to crawl for; other options are
            ['jpg', 'jpeg', 'png', 'tiff'].
        transform: Optional torchvision-style transform. If None, a default
            resize/flip/crop/normalize pipeline is built from the remaining
            arguments.
        image_resize: Target size for `T.Resize`, or None to skip.
        augment_horizontal_flip: If True, add random horizontal flips.
        augment_vertical_flip: If True, add random vertical flips.
        image_crop: Size for `T.CenterCrop`, or None to skip.
    """

    def __init__(
        self,
        path_root,
        item_pointers=None,  # NOTE: was a mutable default ([]), shared across instances
        crawler_ext='tif',
        transform=None,
        image_resize=None,
        augment_horizontal_flip=False,
        augment_vertical_flip=False,
        image_crop=None,
    ):
        super().__init__()
        self.path_root = Path(path_root)
        self.crawler_ext = crawler_ext
        if item_pointers:
            self.item_pointers = item_pointers
        else:
            self.item_pointers = self.run_item_crawler(self.path_root, self.crawler_ext)
        if transform is None:
            self.transform = T.Compose([
                T.Resize(image_resize) if image_resize is not None else nn.Identity(),
                T.RandomHorizontalFlip() if augment_horizontal_flip else nn.Identity(),
                T.RandomVerticalFlip() if augment_vertical_flip else nn.Identity(),
                T.CenterCrop(image_crop) if image_crop is not None else nn.Identity(),
                T.ToTensor(),
                # WARNING: mean and std are not the target values but rather the
                # values to subtract and divide by: [0, 1] -> [0-0.5, 1-0.5]/0.5 -> [-1, 1]
                T.Normalize(mean=0.5, std=0.5)
            ])
        else:
            self.transform = transform

    def __len__(self):
        return len(self.item_pointers)

    def __getitem__(self, index):
        rel_path_item = self.item_pointers[index]
        path_item = self.path_root / rel_path_item
        img = self.load_item(path_item)
        return {'uid': rel_path_item.stem, 'source': self.transform(img)}

    def load_item(self, path_item):
        # NOTE: PIL does not support 16bit RGB images; convert to 8bit RGB.
        return Image.open(path_item).convert('RGB')

    @classmethod
    def run_item_crawler(cls, path_root, extension, **kwargs):
        """Return all paths below `path_root` with `extension`, relative to `path_root`."""
        return [path.relative_to(path_root) for path in Path(path_root).rglob(f'*.{extension}')]

    def get_weights(self):
        """Return list of class-weights for WeightedSampling, or None if unweighted."""
        return None
class AIROGSDataset(SimpleDataset2D):
    """AIROGS fundus-image dataset for glaucoma classification.

    Labels come from ``train_labels.csv`` (one row per `challenge_id`);
    images are expected as ``<challenge_id>.jpg`` under `path_root`.
    """

    # Class counts in the csv: RG = 3270, NRG = 98172.
    LABEL_MAP = {'NRG': 0, 'RG': 1}

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # csv lives next to the image folder, indexed by challenge_id
        self.labels = pd.read_csv(self.path_root.parent / 'train_labels.csv', index_col='challenge_id')

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, index):
        uid = self.labels.index[index]
        path_item = self.path_root / f'{uid}.jpg'
        img = self.load_item(path_item)
        target = self.LABEL_MAP[self.labels.loc[uid, 'class']]
        return {'source': self.transform(img), 'target': target}

    def get_weights(self):
        """Return per-sample inverse-frequency weights for WeightedSampling.

        e.g. {'NRG': 1.03, 'RG': 31.02} — rare classes get larger weights.
        """
        weight_per_class = 1 / self.labels['class'].value_counts(normalize=True)
        # Vectorized lookup; replaces the original O(n) Python loop with a
        # per-row .iloc access over ~100k rows.
        return self.labels['class'].map(weight_per_class).tolist()

    @classmethod
    def run_item_crawler(cls, path_root, extension, **kwargs):
        """Overwrite to speed up as paths are determined by .csv file anyway"""
        return []
class MSIvsMSS_Dataset(SimpleDataset2D):
    """Histology-tile dataset for MSI vs. MSS classification.

    Data: https://doi.org/10.5281/zenodo.2530835
    The class label is encoded in the parent directory name of each tile.
    """

    def __getitem__(self, index):
        rel_path = self.item_pointers[index]
        full_path = self.path_root / rel_path
        image = self.load_item(full_path)
        # Map folder name to integer class id.
        label_lookup = {'MSIMUT': 0, 'MSS': 1}
        label = label_lookup[full_path.parent.name]
        return {'uid': rel_path.stem, 'source': self.transform(image), 'target': label}
class MSIvsMSS_2_Dataset(SimpleDataset2D):
    """Histology-tile dataset, second release.

    Data: https://doi.org/10.5281/zenodo.3832231
    Patients with MSI-H = 'MSIH'; patients with MSI-L and MSS = 'nonMSIH'.
    The class label is encoded in the parent directory name; unlike
    MSIvsMSS_Dataset, no 'uid' key is returned.
    """

    def __getitem__(self, index):
        rel_path = self.item_pointers[index]
        full_path = self.path_root / rel_path
        image = self.load_item(full_path)
        # Map folder name to integer class id.
        label_lookup = {'MSIH': 0, 'nonMSIH': 1}
        label = label_lookup[full_path.parent.name]
        return {'source': self.transform(image), 'target': label}
class CheXpert_Dataset(SimpleDataset2D):
    """CheXpert chest X-ray dataset, restricted to frontal views.

    The split ('train'/'valid') is inferred from the name of `path_root`;
    the matching ``<split>.csv`` is read from its parent directory.
    Targets are the 'Cardiomegaly' labels shifted into [0..3].
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        mode = self.path_root.name  # split name selects the csv file
        df = pd.read_csv(self.path_root.parent / f'{mode}.csv', index_col='Path')
        df = df[df['Frontal/Lateral'] == 'Frontal'].copy()
        # Strip the fixed-length prefix from the csv paths so the index is
        # relative to path_root — assumes the prefix is always 20 chars.
        df.index = df.index.str[20:]
        # Affects 1 case, must be "female" to match stats in publication
        df.loc[df['Sex'] == 'Unknown', 'Sex'] = 'Female'
        df.fillna(2, inplace=True)  # TODO: Find better solution
        encoding = {
            'Sex': {'Male': 0, 'Female': 1},
            'Frontal/Lateral': {'Frontal': 0, 'Lateral': 1},
            'AP/PA': {'AP': 0, 'PA': 1},
        }
        df.replace(encoding, inplace=True)
        self.labels = df

    def __len__(self):
        return self.labels.shape[0]

    def __getitem__(self, index):
        rel_path = self.labels.index[index]
        img = self.load_item(self.path_root / rel_path)
        uid = str(rel_path)
        # Cardiomegaly labels: -1=uncertain, 0=negative, 1=positive,
        # NA (filled with 2 above) -> shifted by +1 into [0, 3].
        target = torch.tensor(self.labels.loc[uid, 'Cardiomegaly'] + 1, dtype=torch.long)
        return {'uid': uid, 'source': self.transform(img), 'target': target}

    @classmethod
    def run_item_crawler(cls, path_root, extension, **kwargs):
        """Skip crawling; items are determined by the csv index."""
        return []
class CheXpert_2_Dataset(SimpleDataset2D):
    """CheXpert variant driven by a relabeled csv.

    Reads the relabeled ``cheXPert_label.csv`` (indexed by (Path, Image Index);
    1 and -1 (uncertain) cases count as positives (1), 0 and NA count as
    negatives (0)), keeps the 'train' fold, and joins the original frontal
    'Cardiomegaly' labels as 'Cardiomegaly_true' for reference.
    Images live under ``path_root/data`` named by zero-padded image index.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        labels = pd.read_csv(self.path_root / 'labels/cheXPert_label.csv', index_col=['Path', 'Image Index'])
        labels = labels.loc[labels['fold'] == 'train'].copy()
        labels = labels.drop(labels='fold', axis=1)
        # Original CheXpert csv: frontal views only, Cardiomegaly column only.
        labels2 = pd.read_csv(self.path_root / 'labels/train.csv', index_col='Path')
        labels2 = labels2.loc[labels2['Frontal/Lateral'] == 'Frontal'].copy()
        labels2 = labels2[['Cardiomegaly', ]].copy()
        # 0 = Negative, 1 = Positive, 2 = Uncertain (-1 and NA collapsed to 2)
        labels2[(labels2 < 0) | labels2.isna()] = 2
        labels = labels.join(labels2['Cardiomegaly'], on=["Path", ], rsuffix='_true')
        self.labels = labels

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, index):
        path_index, image_index = self.labels.index[index]
        # Images are stored flat under 'data', named by zero-padded index.
        path_item = self.path_root / 'data' / f'{image_index:06}.png'
        img = self.load_item(path_item)
        target = int(self.labels.loc[(path_index, image_index), 'Cardiomegaly'])
        return {'source': self.transform(img), 'target': target}

    @classmethod
    def run_item_crawler(cls, path_root, extension, **kwargs):
        """Overwrite to speed up as paths are determined by .csv file anyway"""
        return []

    def get_weights(self):
        """Return per-sample inverse-frequency weights for WeightedSampling.

        e.g. {2.0: 1.2, 1.0: 8.2, 0.0: 24.3} — rare classes get larger weights.
        """
        weight_per_class = 1 / self.labels['Cardiomegaly'].value_counts(normalize=True)
        # Vectorized mapping; replaces the original O(n) Python loop that did a
        # MultiIndex .loc lookup per row.
        return self.labels['Cardiomegaly'].map(weight_per_class).tolist()