diff --git a/python/main.py b/python/main.py
index bed4d67..557ff56 100644
--- a/python/main.py
+++ b/python/main.py
@@ -12,7 +12,6 @@
import socket
import subprocess
import sys
-import time
import unittest
from functools import lru_cache
from io import BytesIO
@@ -86,7 +85,7 @@ class BoxData(BaseModel):
normalized_start: dict[str, float]
normalized_end: dict[str, float]
seg_class: int
- inp_idx: int
+ dcm_hash: str
class BoxDataResponse(BaseModel):
@@ -116,6 +115,11 @@ class UploadFilesResponse(BaseModel):
skip_deidentification: bool
+class SubmitResponse(BaseModel):
+ dicom_data_fps: list[Any]
+ dcm_hashes: list[str]
+
+
def dcm2dictmetadata(ds: pydicom.dataset.Dataset) -> dict[str, dict[str, str]]:
ds_metadata_dict = {}
for ds_attr in ds:
@@ -189,7 +193,7 @@ def test_submit_button(self: TestEndpoints) -> None:
)
hasher = hashlib.sha256()
block_size = 65536
- with Path(json_response[0][1]).open("rb") as file:
+ with Path(json_response["dicom_data_fps"][0][1]).open("rb") as file:
buf = file.read(block_size)
while len(buf) > 0:
hasher.update(buf)
@@ -202,8 +206,8 @@ def test_submit_button(self: TestEndpoints) -> None:
)
-@app.get("/check_existence_of_clean")
-async def check_existence_of_clean() -> UploadFilesResponse:
+@app.get("/get_clean_cache")
+async def get_clean_cache() -> UploadFilesResponse:
session_fp = "./tmp/session-data/clean/de-identified-files/session.json"
proper_dicom_paths = sorted(
glob.glob( # noqa: PTH207
@@ -792,7 +796,7 @@ async def medsam_estimation(boxdata: BoxData) -> BoxDataResponse:
start = boxdata.normalized_start
end = boxdata.normalized_end
seg_class = boxdata.seg_class
- inp_idx = boxdata.inp_idx
+ dcm_hash = boxdata.dcm_hash
bbox = np.array(
[
min(start["x"], end["x"]),
@@ -802,54 +806,71 @@ async def medsam_estimation(boxdata: BoxData) -> BoxDataResponse:
],
)
box_256 = bbox[None, :] * 256
- time.time()
medsam_model = load_model()
- temp_dir = Path("./tmp/session-data/embed")
- embedding = torch.load(temp_dir / f"embed_{inp_idx}.pt")
- hs = np.load(temp_dir / "Hs.npy")
- ws = np.load(temp_dir / "Ws.npy")
+ embed_dp = Path("./tmp/session-data/embed")
+ embedding = torch.load(embed_dp / f"{dcm_hash}.pt")
+ with open(file="./tmp/session-data/shape.json") as f: # noqa: PTH123, ASYNC101
+ h, w = json.load(f)[dcm_hash]
medsam_seg = medsam_inference(
medsam_model,
embedding,
box_256,
(256, 256),
- (hs[inp_idx], ws[inp_idx]),
+ (h, w),
)
medsam_seg = (seg_class * medsam_seg).astype(np.uint8)
return BoxDataResponse(
mask=base64.b64encode(medsam_seg).decode("utf-8"),
- dimensions=[int(ws[inp_idx]), int(hs[inp_idx])],
+ dimensions=[w, h],
)
def prepare_medsam() -> None:
+ def initialize_image_embeddings(dcm_hash, img) -> None: # type: ignore[no-untyped-def] # noqa: ANN001
+ img_embeding_fp = embed_dp / f"{dcm_hash}.pt"
+ if not os.path.exists(img_embeding_fp):
+ two_d = 2
+ img_3c = (
+ np.repeat(img[:, :, None], 3, axis=-1) if len(img.shape) == two_d else img
+ )
+ img_256 = cv2.resize(src=img_3c, dsize=(256, 256)).astype(np.float32)
+ img_256 = (img_256 - img_256.min()) / np.clip(
+ img_256.max() - img_256.min(),
+ a_min=1e-8,
+ a_max=None,
+ )
+ img_256_tensor = torch.tensor(img_256).float().permute(2, 0, 1).unsqueeze(0)
+ with torch.no_grad():
+ embedding = medsam_model.image_encoder(img_256_tensor)
+ torch.save(embedding, img_embeding_fp)
+
medsam_model = load_model()
- raw_fp = Path("./tmp/session-data/raw")
- dcm_fps = sorted(raw_fp.glob("*"))
- temp_dir = Path("./tmp/session-data/embed")
- hs, ws = [], []
- for idx, dcm_fp in enumerate(dcm_fps):
+ raw_dp = Path("./tmp/session-data/raw")
+ embed_dp = Path("./tmp/session-data/embed")
+ clean_dp = Path("./tmp/session-data/clean")
+ dcm_fps = list(raw_dp.glob("**/*.dcm")) + list(clean_dp.glob("**/*.dcm"))
+ embed_fps = glob.glob(os.path.join(str(embed_dp), "**/*.pt")) # noqa: PTH207, PTH118
+ embed_hashes = [embed_fp.split("/")[-1].split(".")[0] for embed_fp in embed_fps]
+ shapes = {}
+ for dcm_fp in dcm_fps:
+ if "raw" in str(dcm_fp):
+ with open(file=dcm_fp, mode="rb") as f: # noqa: PTH123
+ dcm_hash = hashlib.sha256(f.read()).hexdigest()
+ else:
+ dcm_hash = str(dcm_fp).split("/")[-1].split(".")[0]
img = pydicom.dcmread(dcm_fp).pixel_array
- two_d = 2
- img_3c = (
- np.repeat(img[:, :, None], 3, axis=-1) if len(img.shape) == two_d else img
- )
- h, w, _ = img_3c.shape
- hs.append(h)
- ws.append(w)
- img_256 = cv2.resize(src=img_3c, dsize=(256, 256)).astype(np.float32)
- img_256 = (img_256 - img_256.min()) / np.clip(
- img_256.max() - img_256.min(),
- a_min=1e-8,
- a_max=None,
- )
- img_256_tensor = torch.tensor(img_256).float().permute(2, 0, 1).unsqueeze(0)
- with torch.no_grad():
- embedding = medsam_model.image_encoder(img_256_tensor)
- torch.save(embedding, temp_dir / f"embed_{idx}.pt")
-
- np.save(temp_dir / "Hs.npy", np.array(hs))
- np.save(temp_dir / "Ws.npy", np.array(ws))
+ if len(img.shape) == 2: # noqa: PLR2004
+ h, w = img.shape
+ else:
+ h, w = img.shape[:-1]
+ shapes[dcm_hash] = [h, w]
+ if dcm_hash in embed_hashes:
+ continue
+ else: # noqa: RET507
+ initialize_image_embeddings(dcm_hash, img)
+ embed_hashes.append(dcm_hash)
+ with open(file=Path("./tmp/session-data") / "shape.json", mode="w") as f: # type: ignore[assignment] # noqa: PTH123
+ json.dump(shapes, f) # type: ignore[arg-type]
def deidentification_attributes(
@@ -1210,7 +1231,7 @@ def __next__(
if self.n_dicom_files - 1 >= self.DICOM_IDX:
self.raw_dicom_fp, self.clean_dicom_fp = self.dicom_pair_fps[self.DICOM_IDX]
self.deintentify = self.pending_deidentification[self.DICOM_IDX]
- self.raw_dicom_hash = self.raw_dicom_hashes[self.DICOM_IDX]
+ self.dcm_hash = self.raw_dicom_hashes[self.DICOM_IDX]
return True
return False
@@ -1226,7 +1247,7 @@ def define_undefined_clean_dicom_fp(self, clean_dcm) -> None: # type: ignore[no
)
if not Path(clean_dicom_dp).exists():
Path(clean_dicom_dp).mkdir(parents=True)
- clean_dicom_fp = os.path.join(clean_dicom_dp, self.raw_dicom_hash + ".dcm") # noqa: PTH118
+ clean_dicom_fp = os.path.join(clean_dicom_dp, self.dcm_hash + ".dcm") # noqa: PTH118
self.dicom_pair_fps[self.DICOM_IDX][1] = self.clean_dicom_fp = (
clean_dicom_fp
)
@@ -1237,9 +1258,9 @@ def export_processed_data(
bbox_img: NDArray[Any],
) -> None:
if bbox_img is not None:
- bbox_img_fp = os.path.join(self.out_dp, self.raw_dicom_hash + "_bbox.png") # noqa: PTH118
+ bbox_img_fp = os.path.join(self.out_dp, self.dcm_hash + "_bbox.png") # noqa: PTH118
Image.fromarray(bbox_img).save(bbox_img_fp) # type: ignore[no-untyped-call]
- cache_bbox_img(self.raw_dicom_hash)
+ cache_bbox_img(self.dcm_hash)
dcm.save_as(self.clean_dicom_fp)
def export_session(
@@ -1253,7 +1274,7 @@ def export_session(
def dicom_deidentifier(
session: dict, # type: ignore[type-arg]
-) -> tuple[dict[str, dict[str, str]], list[tuple[str]]]:
+) -> tuple[dict[str, dict[str, str]], list[tuple[str]], list[str]]:
if Path("./tmp/session-data/custom-config.csv").is_file():
custom_config_df = pd.read_csv(
filepath_or_buffer="./tmp/session-data/custom-config.csv",
@@ -1338,11 +1359,11 @@ def dicom_deidentifier(
rw_obj.define_undefined_clean_dicom_fp(dcm)
rw_obj.export_processed_data(dcm=dcm, bbox_img=bbox_img) # type: ignore[arg-type]
rw_obj.export_session(session=session)
- return session, rw_obj.dicom_pair_fps # type: ignore[return-value]
+ return session, rw_obj.dicom_pair_fps, rw_obj.raw_dicom_hashes # type: ignore[return-value]
@app.post("/submit_button")
-async def handle_submit_button_click(user_options: UserOptionsClass) -> list[Any]:
+async def handle_submit_button_click(user_options: UserOptionsClass) -> SubmitResponse:
user_options = dict(user_options) # type: ignore[assignment]
user_options["input_dcm_dp"] = "./tmp/session-data/raw" # type: ignore[index]
user_options["output_dcm_dp"] = "./tmp/session-data/clean" # type: ignore[index]
@@ -1357,12 +1378,13 @@ async def handle_submit_button_click(user_options: UserOptionsClass) -> list[Any
"./tmp/session-data/clean/de-identified-files/session.json",
).open() as file:
session = json.load(file)
- session, dicom_pair_fps = dicom_deidentifier( # type: ignore[assignment]
+ session, dicom_data_fps, dcm_hashes = dicom_deidentifier( # type: ignore[assignment]
session=session,
)
if user_options["annotation"]: # type: ignore[index]
prepare_medsam()
- return dicom_pair_fps
+
+ return {"dicom_data_fps": dicom_data_fps, "dcm_hashes": dcm_hashes} # type: ignore[return-value]
def generate_action_groups() -> None:
diff --git a/python/static/script.js b/python/static/script.js
index f6faa0c..c4cfcfb 100644
--- a/python/static/script.js
+++ b/python/static/script.js
@@ -17,7 +17,7 @@ var ctx = OverlayCanvas.getContext('2d');
var ToggleEdit = document.getElementById('ToggleEdit');
var BrushSizeSlider = document.getElementById('BrushSizeSlider');
var BrushSelect = document.getElementById('BrushSelect');
-var LoadDICOM = document.getElementById('ResetDICOM');
+var LoadMask = document.getElementById('ResetMask');
var ExportMasks = document.getElementById('ExportMasks');
var Undo = document.getElementById('Undo');
var Redo = document.getElementById('Redo');
@@ -43,6 +43,7 @@ var current_dicom_data_fp;
var OpenSequences = [];
var DiffEnabled = false;
var dcm_idx_;
+var dcm_hash;
var isEditing = false;
var currentBrush = 'background';
var brushSize = 25;
@@ -56,6 +57,7 @@ var BoxStart = null;
var BoxEnd = null;
var progress_saved = true;
var notificationTimeout;
+var dcm_hashes;
ctx.lineJoin = 'round';
ctx.lineCap = 'round';
const colorMap = {
@@ -327,7 +329,8 @@ async function UpdateDICOMInformation(dcm_idx)
slider_pending_update = false;
LoadingState = true;
dcm_idx_ = dcm_idx;
- current_dicom_data_fp = await dicom_data_fps[dcm_idx_]
+ dcm_hash = dcm_hashes[dcm_idx_];
+ current_dicom_data_fp = await dicom_data_fps[dcm_idx_];
DisplayRadio.value = "cleaned-display-option";
const conversion_info_response = await fetch
(
@@ -400,14 +403,14 @@ nextSlice.addEventListener("click", function () {
});
document.addEventListener("DOMContentLoaded", function() {
- check_existence_of_clean();
+ get_clean_cache();
});
-async function check_existence_of_clean()
+async function get_clean_cache()
{
const dcm_files_response = await fetch
(
- '/check_existence_of_clean',
+ '/get_clean_cache',
{
method: 'GET'
}
@@ -569,7 +572,7 @@ async function submit_dicom_processing_request()
'retain_descriptors': retain_descriptors_input_checkbox.checked,
'patient_pseudo_id_prefix': patient_pseudo_id_prefix_input_text.value
};
- const dicom_data_fps_response = await fetch
+ const dicom_data_response = await fetch
(
'/submit_button',
{
@@ -582,7 +585,9 @@ async function submit_dicom_processing_request()
}
);
- dicom_data_fps = await dicom_data_fps_response.json();
+ dicom_data = await dicom_data_response.json();
+ dicom_data_fps = dicom_data.dicom_data_fps
+ dcm_hashes = dicom_data.dcm_hashes
if (clean_image.checked) {
DisplayRadio.disabled=false;
}
@@ -835,7 +840,7 @@ async function medsam_estimation(normalizedStart,normalizedEnd) {
normalized_start: normalizedStart,
normalized_end: normalizedEnd,
seg_class: classesMap.indexOf(BrushSelect.value),
- inp_idx: dcm_idx_
+ dcm_hash: dcm_hash,
};
const box_response = await fetch(
'/medsam_estimation/',
@@ -992,13 +997,13 @@ async function submit_classes(){
BrushSizeSlider.disabled = false;
Undo.disabled = false;
Redo.disabled = false;
- LoadDICOM.disabled = false;
+ LoadMask.disabled = false;
ExportMasks.disabled = false;
Add.disabled = true;
Remove.disabled = true;
ClassText.disabled = true;
SubmitClasses.disabled = true;
- DisplayRadio.disabled=false;
+ DisplayRadio.disabled = false;
if (classesMap.length !== predefinedClassesMap.length && predefinedClassesMap.length !== 1)
{
var optionModal = new bootstrap.Modal(document.getElementById('optionModal'), {
@@ -1034,6 +1039,16 @@ async function submit_classes(){
});
get_mask_from_file();
}
+ if (classesMap.length === 1)
+ {
+ ToggleEdit.disabled = true;
+ ToggleMask.disabled = true;
+ BrushSizeSlider.disabled = true;
+ Undo.disabled = true;
+ Redo.disabled = true;
+ LoadMask.disabled = true;
+ ExportMasks.disabled = true;
+ }
showNotification("success", "Submitted classes", 1500);
}
@@ -1148,7 +1163,7 @@ function resetGUIElements() {
BrushSizeSlider.disabled = true;
Undo.disabled = true;
Redo.disabled = true;
- LoadDICOM.disabled = true;
+ LoadMask.disabled = true;
ExportMasks.disabled = true;
Add.disabled = true;
Remove.disabled = true;
diff --git a/python/templates/index.html b/python/templates/index.html
index 71ea532..3c70036 100644
--- a/python/templates/index.html
+++ b/python/templates/index.html
@@ -298,7 +298,7 @@