-
Notifications
You must be signed in to change notification settings - Fork 0
/
loader.py
80 lines (65 loc) · 2.31 KB
/
loader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
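# Label encoding used in the CSV's label column:
#   0 = AD  (Alzheimer's disease)
#   1 = MCI (mild cognitive impairment)
#   2 = CN  (cognitively normal)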
import pandas as pd
from torch.utils.data.dataset import Dataset
import nibabel as nib
from torchvision import transforms
import torch
import numpy as np
from torch.utils.data import DataLoader
# Balance the dataset by oversampling the minority classes.
def recreate(df, classes=(0, 1), copies=2):
    """Return *df* with extra copies of the rows of selected classes appended.

    For every row whose label (third column, position 2) is in *classes*,
    *copies* additional rows with the same path (second column, position 1)
    and label are appended after all original rows. With the defaults,
    labels 0 and 1 end up with three times their original count and every
    other label is left unchanged.

    The appended rows are written into the same columns as the originals
    (``df.columns[1]`` / ``df.columns[2]``), so positional lookups such as
    ``df.iloc[i, 1]`` keep working whatever the CSV headers are called.
    Any other columns (e.g. a saved index) are NaN for the appended rows.

    Args:
        df: DataFrame whose second column holds image paths and whose third
            column holds class labels.
        classes: labels to oversample (default: 0 and 1).
        copies: number of extra rows appended per matching row (default 2).

    Returns:
        A new DataFrame with a fresh RangeIndex; *df* itself is not modified.
    """
    path_col, label_col = df.columns[1], df.columns[2]
    paths_all = df.iloc[:, 1]
    labels = df.iloc[:, 2]

    extras = []
    for cls in classes:
        # np.repeat keeps each copy next to its source row: a, a, b, b, ...
        repeated = np.repeat(paths_all[labels == cls].to_numpy(), copies)
        if repeated.size == 0:
            continue  # nothing to add; avoids concatenating empty frames
        extras.append(pd.DataFrame({
            path_col: repeated,
            label_col: [cls] * len(repeated),
        }))

    return pd.concat([df, *extras], axis=0, ignore_index=True)
class newADdataset(Dataset):
    """Map-style dataset of 3-D brain volumes listed in a CSV file.

    The CSV's second column (position 1) holds image paths loadable by
    nibabel and its third column (position 2) holds the class label. For the
    ``'train'`` phase the rows are passed through ``recreate`` to oversample
    labels 0 and 1.
    """

    def __init__(self, csv_path, phase):
        self.df = pd.read_csv(csv_path)
        if phase == 'train':
            self.df = recreate(self.df)

    @staticmethod
    def _preprocess(img):
        """Crop a volume, min-max scale it to [0, 1] and prepend a channel axis.

        Args:
            img: array-like volume; the crop indexes its first three axes and
                singleton axes are then squeezed away (e.g. (X, Y, Z, 1)).

        Returns:
            float32 tensor of shape (1, 80, 100, 76) when the first three axes
            are at least 120 x 130 x 86 (smaller inputs are clipped by slicing).
            A constant volume maps to all zeros instead of NaNs.

        Raises:
            ValueError: if the crop window selects no voxels.
        """
        vol = np.asarray(img, dtype=np.float64)
        # Fixed spatial crop: 80 x 100 x 76 voxels.
        cropped = np.squeeze(vol[40:120, 30:130, 10:86])
        if cropped.size == 0:
            raise ValueError(
                f"crop window [40:120, 30:130, 10:86] selects no voxels "
                f"from volume of shape {vol.shape}"
            )
        lo, hi = cropped.min(), cropped.max()
        if hi == lo:
            # Constant volume: (x - min) / (max - min) would divide by zero.
            scaled = np.zeros_like(cropped)
        else:
            scaled = (cropped - lo) / (hi - lo)
        return torch.from_numpy(scaled).unsqueeze(0).float()

    def __getitem__(self, idx):
        """Return ``(img, label)`` for row *idx*.

        img is the preprocessed float32 tensor (1 x 80 x 100 x 76 for
        sufficiently large volumes); label is the row's third-column value.
        """
        # iloc[row, col] is explicit positional access; iloc[row][col] relies
        # on a deprecated positional fallback of Series.__getitem__.
        img_path = self.df.iloc[idx, 1]
        label = self.df.iloc[idx, 2]
        img = nib.load(img_path).get_fdata()
        return self._preprocess(img), label

    def __len__(self):
        return len(self.df)
def newADdataloader(phase, csv_path, batch_size, *, num_workers=0):
    """Build a DataLoader over ``newADdataset`` for one data split.

    Args:
        phase: split name. ``'train'`` makes the dataset oversample labels
            0 and 1 and turns on shuffling; any other value gives an
            unshuffled loader over the CSV rows as-is.
        csv_path: CSV listing image paths and labels; it is echoed to stdout.
        batch_size: number of samples per batch.
        num_workers: DataLoader worker processes. The default of 0 (load in
            the main process) matches the previously hard-coded value.

    Returns:
        A ``torch.utils.data.DataLoader``.
    """
    print(csv_path)
    dataset = newADdataset(csv_path, phase)
    return DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=(phase == 'train'),  # only the training split is shuffled
    )