-
Notifications
You must be signed in to change notification settings - Fork 0
/
data_preprocessing.py
197 lines (158 loc) · 8.09 KB
/
data_preprocessing.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
import os
import numpy as np
import cv2
import nibabel as nib
from glob import glob
from tqdm import tqdm
from albumentations import Resize, Normalize
from logger import setup_logger
from exception import CustomException
# Module-level logger shared by all functions below; setup_logger is a
# project-local helper (imported from logger.py above).
logger = setup_logger("log_process_____121212")
# Create a directory
def create_dir(path):
    """Create *path* (including missing parents) if it does not exist.

    Args:
        path: directory path to create.

    Raises:
        CustomException: if directory creation fails.
    """
    try:
        if not os.path.exists(path):
            # exist_ok=True closes the race between the exists() check and
            # makedirs() when another process creates the directory first.
            os.makedirs(path, exist_ok=True)
            logger.info(f"Created directory: {path}")
    except Exception as e:
        logger.error(f"Error creating directory {path}: {str(e)}")
        raise CustomException(f"Error creating directory {path}: {str(e)}")
# Glob patterns for the four MRI modalities, in the channel order the
# rest of the pipeline stacks them (flair, t1, t1ce, t2).
_MODALITY_PATTERNS = ("*_flair.nii.gz", "*_t1.nii.gz", "*_t1ce.nii.gz", "*_t2.nii.gz")


def _collect_patient_files(patient_dirs, label):
    """Collect per-patient modality and segmentation file paths.

    Args:
        patient_dirs: iterable of patient directory paths.
        label: phrase used in the warning log when files are missing
            (e.g. "patient" or "validation patient").

    Returns:
        Tuple (images, masks): images is a list of
        (flair, t1, t1ce, t2) path tuples, masks the matching
        segmentation paths. Patients missing any file are skipped
        with a warning.
    """
    images, masks = [], []
    for patient_dir in patient_dirs:
        modalities = [glob(os.path.join(patient_dir, pattern)) for pattern in _MODALITY_PATTERNS]
        seg = glob(os.path.join(patient_dir, "*_seg.nii.gz"))
        if all(modalities) and seg:
            images.append(tuple(hits[0] for hits in modalities))
            masks.append(seg[0])
        else:
            logger.warning(f"Missing modalities for {label}: {os.path.basename(patient_dir)}")
    return images, masks


# Load data from HGG and LGG directories
def load_data(path):
    """Index BraTS-style training (HGG + LGG) and validation data.

    Args:
        path: dataset root containing "Training/HGG", "Training/LGG"
            and "Validation" subdirectories.

    Returns:
        ((train_images, train_masks), (val_images, val_masks)) where
        each *_images entry is a (flair, t1, t1ce, t2) path tuple.

    Raises:
        CustomException: if indexing fails.
    """
    try:
        hgg_patients = sorted(glob(os.path.join(path, 'Training', 'HGG', "*")))
        lgg_patients = sorted(glob(os.path.join(path, 'Training', 'LGG', "*")))
        all_patients = hgg_patients + lgg_patients
        logger.info(f"Found {len(all_patients)} training patients in total.")
        train_images, mask_images = _collect_patient_files(all_patients, "patient")
        logger.info(f"Loaded {len(train_images)} images and {len(mask_images)} masks.")
        val_patients = sorted(glob(os.path.join(path, 'Validation', "*")))
        logger.info(f"Found {len(val_patients)} validation patients in total.")
        val_images, val_masks = _collect_patient_files(val_patients, "validation patient")
        return (train_images, mask_images), (val_images, val_masks)
    except Exception as e:
        logger.error(f"Error loading data: {str(e)}")
        raise CustomException(f"Error loading data: {str(e)}")
# Read NIfTI images
def read_nii(file_path):
    """Load a NIfTI file and return its voxel data as a float array.

    Args:
        file_path: path to a .nii / .nii.gz file.

    Raises:
        CustomException: if the file cannot be loaded.
    """
    try:
        volume = nib.load(file_path)
        return volume.get_fdata()
    except Exception as e:
        message = f"Error reading NIfTI file {file_path}: {str(e)}"
        logger.error(message)
        raise CustomException(message)
# Preprocessing function
def preprocess_image(image, mask):
    """Normalize a 4-channel image/mask pair and resize both to 512x512.

    Args:
        image: H x W x 4 array (flair, t1, t1ce, t2 channels).
        mask: H x W segmentation array.

    Returns:
        (image, mask) after normalization and resizing.

    Raises:
        CustomException: if any transform fails.

    NOTE(review): Normalize with mean=0/std=1 still divides by
    albumentations' default max_pixel_value (255) — confirm this is the
    intended scaling for raw MRI intensities.
    """
    try:
        normalizer = Normalize(mean=(0.0, 0.0, 0.0, 0.0), std=(1.0, 1.0, 1.0, 1.0))
        step = normalizer(image=image, mask=mask)
        step = Resize(512, 512)(image=step["image"], mask=step["mask"])
        return step["image"], step["mask"]
    except Exception as e:
        logger.error(f"Error in preprocessing: {str(e)}")
        raise CustomException(f"Error in preprocessing: {str(e)}")
# Save preprocessed data without augmentation
def save_preprocessed_data(images, masks, save_path):
    """Preprocess the middle axial slice of each patient and save as PNG.

    Args:
        images: list of (flair, t1, t1ce, t2) NIfTI path tuples.
        masks: list of matching segmentation NIfTI paths.
        save_path: output directory; "image/" and "mask/" subfolders
            are created under it.

    Raises:
        CustomException: on any read/preprocess/write failure.
    """
    try:
        # Create the output directories once, not on every iteration.
        image_dir = os.path.join(save_path, "image")
        mask_dir = os.path.join(save_path, "mask")
        create_dir(image_dir)
        create_dir(mask_dir)
        total_saved = 0
        for modalities, mask_file in tqdm(zip(images, masks), total=len(images), desc="Processing Images"):
            patient_name = os.path.basename(os.path.dirname(modalities[0]))
            # Read image modalities and mask
            flair, t1, t1ce, t2 = (read_nii(p) for p in modalities)
            mask_volume = read_nii(mask_file)
            # Stack modalities into a single multi-channel array
            image = np.stack([flair, t1, t1ce, t2], axis=-1)
            slice_idx = image.shape[2] // 2  # middle axial slice
            image_slice = image[:, :, slice_idx, :]
            mask_slice = mask_volume[:, :, slice_idx]
            # Preprocess the image and mask
            image_slice, mask_slice = preprocess_image(image_slice, mask_slice)
            image_path = os.path.join(image_dir, f"{patient_name}.png")
            mask_path = os.path.join(mask_dir, f"{patient_name}.png")
            # Clip before the uint8 cast: segmentation labels > 1 (e.g. 4)
            # multiplied by 255 would otherwise wrap around modulo 256.
            image_u8 = np.clip(image_slice * 255, 0, 255).astype(np.uint8)
            mask_u8 = np.clip(mask_slice * 255, 0, 255).astype(np.uint8)
            success_image = cv2.imwrite(image_path, image_u8)
            success_mask = cv2.imwrite(mask_path, mask_u8)
            if success_image and success_mask:
                total_saved += 1
                logger.info(f"Saved: {image_path} and {mask_path}")
            else:
                logger.warning(f"Failed to save: {image_path} and/or {mask_path}")
        logger.info(f"Total images saved: {total_saved}")
    except Exception as e:
        logger.error(f"Error during data saving: {str(e)}")
        raise CustomException(f"Error during data saving: {str(e)}")
# Main workflow
if __name__ == "__main__":
    try:
        np.random.seed(42)
        # Load the raw training/validation file indexes.
        data_path = r'data/raw'
        (train_images, mask_images), (val_images, val_masks) = load_data(data_path)
        # Create directories to save the preprocessed data
        for split in ("train", "val"):
            create_dir(os.path.join("new_data", split, "image"))
            create_dir(os.path.join("new_data", split, "mask"))
        logger.info("Starting data processing...")
        # Process training data (without augmentation)
        save_preprocessed_data(train_images, mask_images, os.path.join("new_data", "train"))
        # Save validation data as .png files (middle axial slice only).
        for modalities, mask_file in tqdm(zip(val_images, val_masks), total=len(val_images), desc="Processing Validation Images"):
            patient_name = os.path.basename(os.path.dirname(modalities[0]))
            flair, t1, t1ce, t2 = (read_nii(p) for p in modalities)
            mask_volume = read_nii(mask_file)
            slice_idx = flair.shape[2] // 2  # middle axial slice
            # Save the middle slice of each modality; clip before the
            # uint8 cast so values outside [0, 255] do not wrap around.
            for modality, mod_name in zip([flair, t1, t1ce, t2], ["flair", "t1", "t1ce", "t2"]):
                save_path = os.path.join("new_data", "val", "image", f"{patient_name}_{mod_name}.png")
                slice_u8 = np.clip(modality[:, :, slice_idx] * 255, 0, 255).astype(np.uint8)
                cv2.imwrite(save_path, slice_u8)
            # Save the mask's middle slice; labels > 1 (e.g. 4) * 255
            # would overflow uint8 without the clip.
            mask_path = os.path.join("new_data", "val", "mask", f"{patient_name}_mask.png")
            mask_u8 = np.clip(mask_volume[:, :, slice_idx] * 255, 0, 255).astype(np.uint8)
            cv2.imwrite(mask_path, mask_u8)
        logger.info("Validation data saved successfully.")
    except Exception as e:
        logger.error(f"Error in the main workflow: {str(e)}")
        raise CustomException(f"Error in the main workflow: {str(e)}")