-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathinference.py
145 lines (119 loc) · 5.55 KB
/
inference.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
import numpy as np
import onnxruntime as rt
import pydicom
from pydicom.uid import generate_uid
from PIL import Image, ImageOps
import scipy.ndimage
import sys, os, json
from pathlib import Path
# Load the model from the ONNX file and create an inference session.
# NOTE(review): 'export.onnx' is resolved relative to the current working
# directory, so the script must be launched from the folder containing it.
model = rt.InferenceSession('export.onnx')
# Name of the model's first input tensor, used as the key of the feed
# dictionary when calling model.run() during inference.
input_name = model.get_inputs()[0].name
def process_image(file, in_folder, out_folder, series_uid, settings):
    """Run segmentation inference on one DICOM slice and write a color overlay.

    The slice is min-max normalized, resampled to the 288x288 resolution the
    model expects, segmented via ONNX Runtime, and the resulting mask is
    blended over the original image. The blend is stored as a new RGB DICOM
    file in out_folder.

    Args:
        file: Source filename in the format "[series_UID]#[file_UID].dcm".
        in_folder: Folder containing the source DICOM file.
        out_folder: Folder where the processed DICOM file is written.
        series_uid: SeriesInstanceUID to assign to the output series.
        settings: Dict with keys "color" (mask color name), "transparency"
            (blend weight of the background image, 0..1), and "series_offset"
            (added to the original SeriesNumber).

    Returns:
        True on success, False if the image cannot be processed (non-square
        or constant pixel values).
    """
    # Compose the filename of the output DICOM file using the new series UID
    out_filename = series_uid + "#" + file.split("#", 1)[1]
    dcm_file_out = Path(out_folder) / out_filename
    # Read the source DICOM slice
    dcm_file_in = Path(in_folder) / file
    dcm = pydicom.dcmread(dcm_file_in)
    # Min-max normalize the input image to the range [0, 1].
    # BUGFIX: the previous formula divided by max() instead of (max - min),
    # which only yields [0, 1] when the minimum pixel value is 0. It could
    # also divide by zero for an all-zero image; guard against a constant
    # image instead.
    dcmpixel = dcm.pixel_array.astype(np.float64)
    pixel_min = dcmpixel.min()
    pixel_range = dcmpixel.max() - pixel_min
    if pixel_range == 0:
        print("Error: Image has constant pixel values. Not supported by this module")
        return False
    dcmpixel = (dcmpixel - pixel_min) / pixel_range
    # Ensure that the input image has square size
    if dcmpixel.shape[0] != dcmpixel.shape[1]:
        print("Error: Width and height are not equal. Not supported by this module")
        return False
    # Scale input to 288 pixels, as needed by the DL model
    inference_zoom = 288 / float(dcmpixel.shape[0])
    scl_dcmpixel = scipy.ndimage.zoom(dcmpixel, inference_zoom, order=3)
    # Shape data for running inference: (1, 3, 288, 288) float32, i.e. the
    # grayscale slice replicated across three channels with a batch axis
    scl_dcmpixel = np.dstack([scl_dcmpixel] * 3)
    scl_dcmpixel = np.rollaxis(scl_dcmpixel, 2)
    scl_dcmpixel = scl_dcmpixel.astype(np.float32)
    scl_dcmpixel = np.expand_dims(scl_dcmpixel, axis=0)
    # Execute the inference via ONNX Runtime
    outputs = model.run(None, {input_name: scl_dcmpixel})
    # Get segmentation mask and scale back to original resolution.
    # Model outputs <= 0.5 are rendered as foreground (255), others as 0.
    mask = outputs[0]
    binary_mask = np.where(mask[0, 0, :, :] > 0.5, 0, 255).astype(np.ubyte)
    scl_mask = scipy.ndimage.zoom(binary_mask, 1 / inference_zoom, order=3)
    # Colorize segmentation mask
    mask_image = Image.fromarray(scl_mask)
    mask_image = ImageOps.colorize(mask_image, black="black", white=settings["color"])
    # Scale the already-normalized background (input) image to 8-bit gray
    # levels (dcmpixel is in [0, 1] after the min-max normalization above)
    background = (255 * dcmpixel).astype(np.ubyte)
    background_image = Image.fromarray(background).convert("RGB")
    # Blend the two images: "transparency" is the weight of the background
    final_image = Image.blend(mask_image, background_image, settings["transparency"])
    final_array = np.array(final_image).astype(np.uint8)
    # Write the final image back to a new DICOM (color) image
    dcm.SeriesInstanceUID = series_uid
    dcm.SOPInstanceUID = generate_uid()
    dcm.SeriesNumber = dcm.SeriesNumber + settings["series_offset"]
    dcm.file_meta.TransferSyntaxUID = pydicom.uid.ExplicitVRLittleEndian
    dcm.Rows = final_image.height
    dcm.Columns = final_image.width
    dcm.PhotometricInterpretation = "RGB"
    dcm.SamplesPerPixel = 3
    # BUGFIX: PlanarConfiguration is required by DICOM whenever
    # SamplesPerPixel > 1; 0 = color-by-pixel (R1G1B1 R2G2B2 ...), matching
    # the interleaved layout produced by tobytes() below
    dcm.PlanarConfiguration = 0
    dcm.BitsStored = 8
    dcm.BitsAllocated = 8
    dcm.HighBit = 7
    dcm.PixelRepresentation = 0
    dcm.PixelData = final_array.tobytes()
    dcm.SeriesDescription = "SEG(" + dcm.SeriesDescription + ")"
    dcm.save_as(dcm_file_out)
    return True
def main(args=None):
    """Entry point: segment every DICOM series found in the input folder.

    Args:
        args: Command-line arguments [input-folder, output-folder]. Defaults
            to sys.argv[1:] when not provided.

    Raises:
        SystemExit: With code 1 on missing arguments, missing folders, or an
            unreadable task.json file.
    """
    # BUGFIX: the old signature evaluated sys.argv[1:] once at import time
    # and then ignored the parameter entirely, reading sys.argv directly in
    # the body - main(["in", "out"]) had no effect. Resolve the default
    # lazily and use the parameter consistently.
    if args is None:
        args = sys.argv[1:]
    print("")
    print("AI-based prostate-segmentation example for mercure (v 0.2)")
    print("----------------------------------------------------------")
    print("")
    # Check if the input and output folders are provided as arguments
    if len(args) < 2:
        print("Error: Missing arguments!")
        print("Usage: inference.py [input-folder] [output-folder]")
        sys.exit(1)
    # Check if the input and output folders actually exist
    in_folder = args[0]
    out_folder = args[1]
    if not Path(in_folder).exists() or not Path(out_folder).exists():
        print("IN/OUT paths do not exist")
        sys.exit(1)
    # Load the task.json file, containing the settings for the processing module.
    # BUGFIX: narrowed from a blanket "except Exception" that reported
    # "not found" even for a malformed file; the message now covers both cases.
    try:
        with open(Path(in_folder) / "task.json", "r") as json_file:
            task = json.load(json_file)
    except (OSError, json.JSONDecodeError):
        print("Error: Task file task.json could not be read")
        sys.exit(1)
    # Create default values for all module settings
    settings = {"color": "yellow", "transparency": 0.75, "series_offset": 1000}
    # Overwrite default values with settings from the task file (if present)
    if task.get("process", ""):
        settings.update(task["process"].get("settings", {}))
    # Collect all DICOM series in the input folder. By convention, DICOM files
    # provided by mercure have the format [series_UID]#[file_UID].dcm. Thus, by
    # splitting the file name at the "#" character, the series UID is obtained.
    series = {}
    for entry in os.scandir(in_folder):
        if entry.name.endswith(".dcm") and not entry.is_dir():
            # Get the Series UID from the file name and append the file to the
            # series' list (creating the list on first sight of the series)
            series_string = entry.name.split("#", 1)[0]
            series.setdefault(series_string, []).append(entry.name)
    # Now loop over all series found
    for item in series:
        # Create a new series UID, which will be used for the modified DICOM
        # series (to avoid collision with the original series)
        print("Processing series " + item)
        series_uid = generate_uid()
        # Loop over all slices of the current series and call the processing function
        for image_filename in series[item]:
            print("Processing slice " + image_filename)
            process_image(image_filename, in_folder, out_folder, series_uid, settings)
# Run the processing pipeline only when executed as a script, not on import
if __name__ == "__main__":
    main()