diff --git a/python/main.py b/python/main.py index bed4d67..557ff56 100644 --- a/python/main.py +++ b/python/main.py @@ -12,7 +12,6 @@ import socket import subprocess import sys -import time import unittest from functools import lru_cache from io import BytesIO @@ -86,7 +85,7 @@ class BoxData(BaseModel): normalized_start: dict[str, float] normalized_end: dict[str, float] seg_class: int - inp_idx: int + dcm_hash: str class BoxDataResponse(BaseModel): @@ -116,6 +115,11 @@ class UploadFilesResponse(BaseModel): skip_deidentification: bool +class SubmitResponse(BaseModel): + dicom_data_fps: list[Any] + dcm_hashes: list[str] + + def dcm2dictmetadata(ds: pydicom.dataset.Dataset) -> dict[str, dict[str, str]]: ds_metadata_dict = {} for ds_attr in ds: @@ -189,7 +193,7 @@ def test_submit_button(self: TestEndpoints) -> None: ) hasher = hashlib.sha256() block_size = 65536 - with Path(json_response[0][1]).open("rb") as file: + with Path(json_response["dicom_data_fps"][0][1]).open("rb") as file: buf = file.read(block_size) while len(buf) > 0: hasher.update(buf) @@ -202,8 +206,8 @@ def test_submit_button(self: TestEndpoints) -> None: ) -@app.get("/check_existence_of_clean") -async def check_existence_of_clean() -> UploadFilesResponse: +@app.get("/get_clean_cache") +async def get_clean_cache() -> UploadFilesResponse: session_fp = "./tmp/session-data/clean/de-identified-files/session.json" proper_dicom_paths = sorted( glob.glob( # noqa: PTH207 @@ -792,7 +796,7 @@ async def medsam_estimation(boxdata: BoxData) -> BoxDataResponse: start = boxdata.normalized_start end = boxdata.normalized_end seg_class = boxdata.seg_class - inp_idx = boxdata.inp_idx + dcm_hash = boxdata.dcm_hash bbox = np.array( [ min(start["x"], end["x"]), @@ -802,54 +806,71 @@ async def medsam_estimation(boxdata: BoxData) -> BoxDataResponse: ], ) box_256 = bbox[None, :] * 256 - time.time() medsam_model = load_model() - temp_dir = Path("./tmp/session-data/embed") - embedding = torch.load(temp_dir / f"embed_{inp_idx}.pt") - hs = np.load(temp_dir / "Hs.npy") - ws = np.load(temp_dir / "Ws.npy") + embed_dp = Path("./tmp/session-data/embed") + embedding = torch.load(embed_dp / f"{dcm_hash}.pt") + with open(file="./tmp/session-data/shape.json") as f: # noqa: PTH123, ASYNC101 + h, w = json.load(f)[dcm_hash] medsam_seg = medsam_inference( medsam_model, embedding, box_256, (256, 256), - (hs[inp_idx], ws[inp_idx]), + (h, w), ) medsam_seg = (seg_class * medsam_seg).astype(np.uint8) return BoxDataResponse( mask=base64.b64encode(medsam_seg).decode("utf-8"), - dimensions=[int(ws[inp_idx]), int(hs[inp_idx])], + dimensions=[w, h], ) def prepare_medsam() -> None: + def initialize_image_embeddings(dcm_hash, img) -> None: # type: ignore[no-untyped-def] # noqa: ANN001 + img_embeding_fp = embed_dp / f"{dcm_hash}.pt" + if not os.path.exists(img_embeding_fp): + two_d = 2 + img_3c = ( + np.repeat(img[:, :, None], 3, axis=-1) if len(img.shape) == two_d else img + ) + img_256 = cv2.resize(src=img_3c, dsize=(256, 256)).astype(np.float32) + img_256 = (img_256 - img_256.min()) / np.clip( + img_256.max() - img_256.min(), + a_min=1e-8, + a_max=None, + ) + img_256_tensor = torch.tensor(img_256).float().permute(2, 0, 1).unsqueeze(0) + with torch.no_grad(): + embedding = medsam_model.image_encoder(img_256_tensor) + torch.save(embedding, img_embeding_fp) + medsam_model = load_model() - raw_fp = Path("./tmp/session-data/raw") - dcm_fps = sorted(raw_fp.glob("*")) - temp_dir = Path("./tmp/session-data/embed") - hs, ws = [], [] - for idx, dcm_fp in enumerate(dcm_fps): + raw_dp = Path("./tmp/session-data/raw") + embed_dp = Path("./tmp/session-data/embed") + clean_dp = Path("./tmp/session-data/clean") + dcm_fps = list(raw_dp.glob("**/*.dcm")) + list(clean_dp.glob("**/*.dcm")) + embed_fps = glob.glob(os.path.join(str(embed_dp), "**/*.pt")) # noqa: PTH207, PTH118 + embed_hashes = [embed_fp.split("/")[-1].split(".")[0] for embed_fp in embed_fps] + shapes = {} + for dcm_fp in dcm_fps: + if "raw" in str(dcm_fp): + with open(file=dcm_fp, mode="rb") as f: # noqa: PTH123 + dcm_hash = hashlib.sha256(f.read()).hexdigest() + else: + dcm_hash = str(dcm_fp).split("/")[-1].split(".")[0] img = pydicom.dcmread(dcm_fp).pixel_array - two_d = 2 - img_3c = ( - np.repeat(img[:, :, None], 3, axis=-1) if len(img.shape) == two_d else img - ) - h, w, _ = img_3c.shape - hs.append(h) - ws.append(w) - img_256 = cv2.resize(src=img_3c, dsize=(256, 256)).astype(np.float32) - img_256 = (img_256 - img_256.min()) / np.clip( - img_256.max() - img_256.min(), - a_min=1e-8, - a_max=None, - ) - img_256_tensor = torch.tensor(img_256).float().permute(2, 0, 1).unsqueeze(0) - with torch.no_grad(): - embedding = medsam_model.image_encoder(img_256_tensor) - torch.save(embedding, temp_dir / f"embed_{idx}.pt") - - np.save(temp_dir / "Hs.npy", np.array(hs)) - np.save(temp_dir / "Ws.npy", np.array(ws)) + if len(img.shape) == 2: # noqa: PLR2004 + h, w = img.shape + else: + h, w = img.shape[:-1] + shapes[dcm_hash] = [h, w] + if dcm_hash in embed_hashes: + continue + else: # noqa: RET507 + initialize_image_embeddings(dcm_hash, img) + embed_hashes.append(dcm_hash) + with open(file=Path("./tmp/session-data") / "shape.json", mode="w") as f: # type: ignore[assignment] # noqa: PTH123 + json.dump(shapes, f) # type: ignore[arg-type] def deidentification_attributes( @@ -1210,7 +1231,7 @@ def __next__( if self.n_dicom_files - 1 >= self.DICOM_IDX: self.raw_dicom_fp, self.clean_dicom_fp = self.dicom_pair_fps[self.DICOM_IDX] self.deintentify = self.pending_deidentification[self.DICOM_IDX] - self.raw_dicom_hash = self.raw_dicom_hashes[self.DICOM_IDX] + self.dcm_hash = self.raw_dicom_hashes[self.DICOM_IDX] return True return False @@ -1226,7 +1247,7 @@ def define_undefined_clean_dicom_fp(self, clean_dcm) -> None: # type: ignore[no ) if not Path(clean_dicom_dp).exists(): Path(clean_dicom_dp).mkdir(parents=True) - clean_dicom_fp = os.path.join(clean_dicom_dp, self.raw_dicom_hash + ".dcm") # noqa: PTH118 + clean_dicom_fp = os.path.join(clean_dicom_dp, self.dcm_hash + ".dcm") # noqa: PTH118 self.dicom_pair_fps[self.DICOM_IDX][1] = self.clean_dicom_fp = ( clean_dicom_fp ) @@ -1237,9 +1258,9 @@ def export_processed_data( bbox_img: NDArray[Any], ) -> None: if bbox_img is not None: - bbox_img_fp = os.path.join(self.out_dp, self.raw_dicom_hash + "_bbox.png") # noqa: PTH118 + bbox_img_fp = os.path.join(self.out_dp, self.dcm_hash + "_bbox.png") # noqa: PTH118 Image.fromarray(bbox_img).save(bbox_img_fp) # type: ignore[no-untyped-call] - cache_bbox_img(self.raw_dicom_hash) + cache_bbox_img(self.dcm_hash) dcm.save_as(self.clean_dicom_fp) def export_session( @@ -1253,7 +1274,7 @@ def export_session( def dicom_deidentifier( session: dict, # type: ignore[type-arg] -) -> tuple[dict[str, dict[str, str]], list[tuple[str]]]: +) -> tuple[dict[str, dict[str, str]], list[tuple[str]], list[str]]: if Path("./tmp/session-data/custom-config.csv").is_file(): custom_config_df = pd.read_csv( filepath_or_buffer="./tmp/session-data/custom-config.csv", @@ -1338,11 +1359,11 @@ def dicom_deidentifier( rw_obj.define_undefined_clean_dicom_fp(dcm) rw_obj.export_processed_data(dcm=dcm, bbox_img=bbox_img) # type: ignore[arg-type] rw_obj.export_session(session=session) - return session, rw_obj.dicom_pair_fps # type: ignore[return-value] + return session, rw_obj.dicom_pair_fps, rw_obj.raw_dicom_hashes # type: ignore[return-value] @app.post("/submit_button") -async def handle_submit_button_click(user_options: UserOptionsClass) -> list[Any]: +async def handle_submit_button_click(user_options: UserOptionsClass) -> SubmitResponse: user_options = dict(user_options) # type: ignore[assignment] user_options["input_dcm_dp"] = "./tmp/session-data/raw" # type: ignore[index] user_options["output_dcm_dp"] = "./tmp/session-data/clean" # type: ignore[index] @@ -1357,12 +1378,13 @@ async def handle_submit_button_click(user_options: UserOptionsClass) -> list[Any "./tmp/session-data/clean/de-identified-files/session.json", ).open() as file: session = json.load(file) - session, dicom_pair_fps = dicom_deidentifier( # type: ignore[assignment] + session, dicom_data_fps, dcm_hashes = dicom_deidentifier( # type: ignore[assignment] session=session, ) if user_options["annotation"]: # type: ignore[index] prepare_medsam() - return dicom_pair_fps + + return {"dicom_data_fps": dicom_data_fps, "dcm_hashes": dcm_hashes} # type: ignore[return-value] def generate_action_groups() -> None: diff --git a/python/static/script.js b/python/static/script.js index f6faa0c..c4cfcfb 100644 --- a/python/static/script.js +++ b/python/static/script.js @@ -17,7 +17,7 @@ var ctx = OverlayCanvas.getContext('2d'); var ToggleEdit = document.getElementById('ToggleEdit'); var BrushSizeSlider = document.getElementById('BrushSizeSlider'); var BrushSelect = document.getElementById('BrushSelect'); -var LoadDICOM = document.getElementById('ResetDICOM'); +var LoadMask = document.getElementById('ResetMask'); var ExportMasks = document.getElementById('ExportMasks'); var Undo = document.getElementById('Undo'); var Redo = document.getElementById('Redo'); @@ -43,6 +43,7 @@ var current_dicom_data_fp; var OpenSequences = []; var DiffEnabled = false; var dcm_idx_; +var dcm_hash; var isEditing = false; var currentBrush = 'background'; var brushSize = 25; @@ -56,6 +57,7 @@ var BoxStart = null; var BoxEnd = null; var progress_saved = true; var notificationTimeout; +var dcm_hashes; ctx.lineJoin = 'round'; ctx.lineCap = 'round'; const colorMap = { @@ -327,7 +329,8 @@ async function UpdateDICOMInformation(dcm_idx) slider_pending_update = false; LoadingState = true; dcm_idx_ = dcm_idx; - current_dicom_data_fp = await dicom_data_fps[dcm_idx_] + dcm_hash = dcm_hashes[dcm_idx_]; + current_dicom_data_fp = await dicom_data_fps[dcm_idx_]; DisplayRadio.value = "cleaned-display-option"; const conversion_info_response = await fetch ( @@ -400,14 +403,14 @@ nextSlice.addEventListener("click", function () { }); document.addEventListener("DOMContentLoaded", function() { - check_existence_of_clean(); + get_clean_cache(); }); -async function check_existence_of_clean() +async function get_clean_cache() { const dcm_files_response = await fetch ( - '/check_existence_of_clean', + '/get_clean_cache', { method: 'GET' } @@ -569,7 +572,7 @@ async function submit_dicom_processing_request() 'retain_descriptors': retain_descriptors_input_checkbox.checked, 'patient_pseudo_id_prefix': patient_pseudo_id_prefix_input_text.value }; - const dicom_data_fps_response = await fetch + const dicom_data_response = await fetch ( '/submit_button', { @@ -582,7 +585,9 @@ async function submit_dicom_processing_request() } ); - dicom_data_fps = await dicom_data_fps_response.json(); + dicom_data = await dicom_data_response.json(); + dicom_data_fps = dicom_data.dicom_data_fps + dcm_hashes = dicom_data.dcm_hashes if (clean_image.checked) { DisplayRadio.disabled=false; } @@ -835,7 +840,7 @@ async function medsam_estimation(normalizedStart,normalizedEnd) { normalized_start: normalizedStart, normalized_end: normalizedEnd, seg_class: classesMap.indexOf(BrushSelect.value), - inp_idx: dcm_idx_ + dcm_hash: dcm_hash, }; const box_response = await fetch( '/medsam_estimation/', @@ -992,13 +997,13 @@ async function submit_classes(){ BrushSizeSlider.disabled = false; Undo.disabled = false; Redo.disabled = false; - LoadDICOM.disabled = false; + LoadMask.disabled = false; ExportMasks.disabled = false; Add.disabled = true; Remove.disabled = true; ClassText.disabled = true; SubmitClasses.disabled = true; - DisplayRadio.disabled=false; + DisplayRadio.disabled = false; if (classesMap.length !== predefinedClassesMap.length && predefinedClassesMap.length !== 1) { var optionModal = new bootstrap.Modal(document.getElementById('optionModal'), { @@ -1034,6 +1039,16 @@ async function submit_classes(){ }); get_mask_from_file(); } + if (classesMap.length === 1) + { + ToggleEdit.disabled = true; + ToggleMask.disabled = true; + BrushSizeSlider.disabled = true; + Undo.disabled = true; + Redo.disabled = true; + LoadMask.disabled = true; + ExportMasks.disabled = true; + } showNotification("success", "Submitted classes", 1500); } @@ -1148,7 +1163,7 @@ function resetGUIElements() { BrushSizeSlider.disabled = true; Undo.disabled = true; Redo.disabled = true; - LoadDICOM.disabled = true; + LoadMask.disabled = true; ExportMasks.disabled = true; Add.disabled = true; Remove.disabled = true; diff --git a/python/templates/index.html b/python/templates/index.html index 71ea532..3c70036 100644 --- a/python/templates/index.html +++ b/python/templates/index.html @@ -298,7 +298,7 @@