Skip to content

Commit

Permalink
[Improvement] Image embeddings persistence among page refreshes
Browse files Browse the repository at this point in the history
  • Loading branch information
fl0wxr committed Apr 17, 2024
1 parent f38c4da commit 1423470
Show file tree
Hide file tree
Showing 3 changed files with 96 additions and 59 deletions.
116 changes: 69 additions & 47 deletions python/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
import socket
import subprocess
import sys
import time
import unittest
from functools import lru_cache
from io import BytesIO
Expand Down Expand Up @@ -86,7 +85,7 @@ class BoxData(BaseModel):
normalized_start: dict[str, float]
normalized_end: dict[str, float]
seg_class: int
inp_idx: int
dcm_hash: str


class BoxDataResponse(BaseModel):
Expand Down Expand Up @@ -116,6 +115,11 @@ class UploadFilesResponse(BaseModel):
skip_deidentification: bool


class SubmitResponse(BaseModel):
dicom_data_fps: list[Any]
dcm_hashes: list[str]


def dcm2dictmetadata(ds: pydicom.dataset.Dataset) -> dict[str, dict[str, str]]:
ds_metadata_dict = {}
for ds_attr in ds:
Expand Down Expand Up @@ -189,7 +193,7 @@ def test_submit_button(self: TestEndpoints) -> None:
)
hasher = hashlib.sha256()
block_size = 65536
with Path(json_response[0][1]).open("rb") as file:
with Path(json_response["dicom_data_fps"][0][1]).open("rb") as file:
buf = file.read(block_size)
while len(buf) > 0:
hasher.update(buf)
Expand All @@ -202,8 +206,8 @@ def test_submit_button(self: TestEndpoints) -> None:
)


@app.get("/check_existence_of_clean")
async def check_existence_of_clean() -> UploadFilesResponse:
@app.get("/get_clean_cache")
async def get_clean_cache() -> UploadFilesResponse:
session_fp = "./tmp/session-data/clean/de-identified-files/session.json"
proper_dicom_paths = sorted(
glob.glob( # noqa: PTH207
Expand Down Expand Up @@ -792,7 +796,7 @@ async def medsam_estimation(boxdata: BoxData) -> BoxDataResponse:
start = boxdata.normalized_start
end = boxdata.normalized_end
seg_class = boxdata.seg_class
inp_idx = boxdata.inp_idx
dcm_hash = boxdata.dcm_hash
bbox = np.array(
[
min(start["x"], end["x"]),
Expand All @@ -802,54 +806,71 @@ async def medsam_estimation(boxdata: BoxData) -> BoxDataResponse:
],
)
box_256 = bbox[None, :] * 256
time.time()
medsam_model = load_model()
temp_dir = Path("./tmp/session-data/embed")
embedding = torch.load(temp_dir / f"embed_{inp_idx}.pt")
hs = np.load(temp_dir / "Hs.npy")
ws = np.load(temp_dir / "Ws.npy")
embed_dp = Path("./tmp/session-data/embed")
embedding = torch.load(embed_dp / f"{dcm_hash}.pt")
with open(file="./tmp/session-data/shape.json") as f: # noqa: PTH123, ASYNC101
h, w = json.load(f)[dcm_hash]
medsam_seg = medsam_inference(
medsam_model,
embedding,
box_256,
(256, 256),
(hs[inp_idx], ws[inp_idx]),
(h, w),
)
medsam_seg = (seg_class * medsam_seg).astype(np.uint8)
return BoxDataResponse(
mask=base64.b64encode(medsam_seg).decode("utf-8"),
dimensions=[int(ws[inp_idx]), int(hs[inp_idx])],
dimensions=[w, h],
)


def prepare_medsam() -> None:
def initialize_image_embeddings(dcm_hash, img) -> None: # type: ignore[no-untyped-def] # noqa: ANN001
img_embeding_fp = embed_dp / f"{dcm_hash}.pt"
if not os.path.exists(img_embeding_fp):
two_d = 2
img_3c = (
np.repeat(img[:, :, None], 3, axis=-1) if len(img.shape) == two_d else img
)
img_256 = cv2.resize(src=img_3c, dsize=(256, 256)).astype(np.float32)
img_256 = (img_256 - img_256.min()) / np.clip(
img_256.max() - img_256.min(),
a_min=1e-8,
a_max=None,
)
img_256_tensor = torch.tensor(img_256).float().permute(2, 0, 1).unsqueeze(0)
with torch.no_grad():
embedding = medsam_model.image_encoder(img_256_tensor)
torch.save(embedding, img_embeding_fp)

medsam_model = load_model()
raw_fp = Path("./tmp/session-data/raw")
dcm_fps = sorted(raw_fp.glob("*"))
temp_dir = Path("./tmp/session-data/embed")
hs, ws = [], []
for idx, dcm_fp in enumerate(dcm_fps):
raw_dp = Path("./tmp/session-data/raw")
embed_dp = Path("./tmp/session-data/embed")
clean_dp = Path("./tmp/session-data/clean")
dcm_fps = list(raw_dp.glob("**/*.dcm")) + list(clean_dp.glob("**/*.dcm"))
embed_fps = glob.glob(os.path.join(str(embed_dp), "**/*.pt")) # noqa: PTH207, PTH118
embed_hashes = [embed_fp.split("/")[-1].split(".")[0] for embed_fp in embed_fps]
shapes = {}
for dcm_fp in dcm_fps:
if "raw" in str(dcm_fp):
with open(file=dcm_fp, mode="rb") as f: # noqa: PTH123
dcm_hash = hashlib.sha256(f.read()).hexdigest()
else:
dcm_hash = str(dcm_fp).split("/")[-1].split(".")[0]
img = pydicom.dcmread(dcm_fp).pixel_array
two_d = 2
img_3c = (
np.repeat(img[:, :, None], 3, axis=-1) if len(img.shape) == two_d else img
)
h, w, _ = img_3c.shape
hs.append(h)
ws.append(w)
img_256 = cv2.resize(src=img_3c, dsize=(256, 256)).astype(np.float32)
img_256 = (img_256 - img_256.min()) / np.clip(
img_256.max() - img_256.min(),
a_min=1e-8,
a_max=None,
)
img_256_tensor = torch.tensor(img_256).float().permute(2, 0, 1).unsqueeze(0)
with torch.no_grad():
embedding = medsam_model.image_encoder(img_256_tensor)
torch.save(embedding, temp_dir / f"embed_{idx}.pt")

np.save(temp_dir / "Hs.npy", np.array(hs))
np.save(temp_dir / "Ws.npy", np.array(ws))
if len(img.shape) == 2: # noqa: PLR2004
h, w = img.shape
else:
h, w = img.shape[:-1]
shapes[dcm_hash] = [h, w]
if dcm_hash in embed_hashes:
continue
else: # noqa: RET507
initialize_image_embeddings(dcm_hash, img)
embed_hashes.append(dcm_hash)
with open(file=Path("./tmp/session-data") / "shape.json", mode="w") as f: # type: ignore[assignment] # noqa: PTH123
json.dump(shapes, f) # type: ignore[arg-type]


def deidentification_attributes(
Expand Down Expand Up @@ -1210,7 +1231,7 @@ def __next__(
if self.n_dicom_files - 1 >= self.DICOM_IDX:
self.raw_dicom_fp, self.clean_dicom_fp = self.dicom_pair_fps[self.DICOM_IDX]
self.deintentify = self.pending_deidentification[self.DICOM_IDX]
self.raw_dicom_hash = self.raw_dicom_hashes[self.DICOM_IDX]
self.dcm_hash = self.raw_dicom_hashes[self.DICOM_IDX]
return True
return False

Expand All @@ -1226,7 +1247,7 @@ def define_undefined_clean_dicom_fp(self, clean_dcm) -> None: # type: ignore[no
)
if not Path(clean_dicom_dp).exists():
Path(clean_dicom_dp).mkdir(parents=True)
clean_dicom_fp = os.path.join(clean_dicom_dp, self.raw_dicom_hash + ".dcm") # noqa: PTH118
clean_dicom_fp = os.path.join(clean_dicom_dp, self.dcm_hash + ".dcm") # noqa: PTH118
self.dicom_pair_fps[self.DICOM_IDX][1] = self.clean_dicom_fp = (
clean_dicom_fp
)
Expand All @@ -1237,9 +1258,9 @@ def export_processed_data(
bbox_img: NDArray[Any],
) -> None:
if bbox_img is not None:
bbox_img_fp = os.path.join(self.out_dp, self.raw_dicom_hash + "_bbox.png") # noqa: PTH118
bbox_img_fp = os.path.join(self.out_dp, self.dcm_hash + "_bbox.png") # noqa: PTH118
Image.fromarray(bbox_img).save(bbox_img_fp) # type: ignore[no-untyped-call]
cache_bbox_img(self.raw_dicom_hash)
cache_bbox_img(self.dcm_hash)
dcm.save_as(self.clean_dicom_fp)

def export_session(
Expand All @@ -1253,7 +1274,7 @@ def export_session(

def dicom_deidentifier(
session: dict, # type: ignore[type-arg]
) -> tuple[dict[str, dict[str, str]], list[tuple[str]]]:
) -> tuple[dict[str, dict[str, str]], list[tuple[str]], list[str]]:
if Path("./tmp/session-data/custom-config.csv").is_file():
custom_config_df = pd.read_csv(
filepath_or_buffer="./tmp/session-data/custom-config.csv",
Expand Down Expand Up @@ -1338,11 +1359,11 @@ def dicom_deidentifier(
rw_obj.define_undefined_clean_dicom_fp(dcm)
rw_obj.export_processed_data(dcm=dcm, bbox_img=bbox_img) # type: ignore[arg-type]
rw_obj.export_session(session=session)
return session, rw_obj.dicom_pair_fps # type: ignore[return-value]
return session, rw_obj.dicom_pair_fps, rw_obj.raw_dicom_hashes # type: ignore[return-value]


@app.post("/submit_button")
async def handle_submit_button_click(user_options: UserOptionsClass) -> list[Any]:
async def handle_submit_button_click(user_options: UserOptionsClass) -> SubmitResponse:
user_options = dict(user_options) # type: ignore[assignment]
user_options["input_dcm_dp"] = "./tmp/session-data/raw" # type: ignore[index]
user_options["output_dcm_dp"] = "./tmp/session-data/clean" # type: ignore[index]
Expand All @@ -1357,12 +1378,13 @@ async def handle_submit_button_click(user_options: UserOptionsClass) -> list[Any
"./tmp/session-data/clean/de-identified-files/session.json",
).open() as file:
session = json.load(file)
session, dicom_pair_fps = dicom_deidentifier( # type: ignore[assignment]
session, dicom_data_fps, dcm_hashes = dicom_deidentifier( # type: ignore[assignment]
session=session,
)
if user_options["annotation"]: # type: ignore[index]
prepare_medsam()
return dicom_pair_fps

return {"dicom_data_fps": dicom_data_fps, "dcm_hashes": dcm_hashes} # type: ignore[return-value]


def generate_action_groups() -> None:
Expand Down
37 changes: 26 additions & 11 deletions python/static/script.js
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ var ctx = OverlayCanvas.getContext('2d');
var ToggleEdit = document.getElementById('ToggleEdit');
var BrushSizeSlider = document.getElementById('BrushSizeSlider');
var BrushSelect = document.getElementById('BrushSelect');
var LoadDICOM = document.getElementById('ResetDICOM');
var LoadMask = document.getElementById('ResetMask');
var ExportMasks = document.getElementById('ExportMasks');
var Undo = document.getElementById('Undo');
var Redo = document.getElementById('Redo');
Expand All @@ -43,6 +43,7 @@ var current_dicom_data_fp;
var OpenSequences = [];
var DiffEnabled = false;
var dcm_idx_;
var dcm_hash;
var isEditing = false;
var currentBrush = 'background';
var brushSize = 25;
Expand All @@ -56,6 +57,7 @@ var BoxStart = null;
var BoxEnd = null;
var progress_saved = true;
var notificationTimeout;
var dcm_hashes;
ctx.lineJoin = 'round';
ctx.lineCap = 'round';
const colorMap = {
Expand Down Expand Up @@ -327,7 +329,8 @@ async function UpdateDICOMInformation(dcm_idx)
slider_pending_update = false;
LoadingState = true;
dcm_idx_ = dcm_idx;
current_dicom_data_fp = await dicom_data_fps[dcm_idx_]
dcm_hash = dcm_hashes[dcm_idx_];
current_dicom_data_fp = await dicom_data_fps[dcm_idx_];
DisplayRadio.value = "cleaned-display-option";
const conversion_info_response = await fetch
(
Expand Down Expand Up @@ -400,14 +403,14 @@ nextSlice.addEventListener("click", function () {
});

document.addEventListener("DOMContentLoaded", function() {
check_existence_of_clean();
get_clean_cache();
});

async function check_existence_of_clean()
async function get_clean_cache()
{
const dcm_files_response = await fetch
(
'/check_existence_of_clean',
'/get_clean_cache',
{
method: 'GET'
}
Expand Down Expand Up @@ -569,7 +572,7 @@ async function submit_dicom_processing_request()
'retain_descriptors': retain_descriptors_input_checkbox.checked,
'patient_pseudo_id_prefix': patient_pseudo_id_prefix_input_text.value
};
const dicom_data_fps_response = await fetch
const dicom_data_response = await fetch
(
'/submit_button',
{
Expand All @@ -582,7 +585,9 @@ async function submit_dicom_processing_request()
}
);

dicom_data_fps = await dicom_data_fps_response.json();
dicom_data = await dicom_data_response.json();
dicom_data_fps = dicom_data.dicom_data_fps
dcm_hashes = dicom_data.dcm_hashes
if (clean_image.checked) {
DisplayRadio.disabled=false;
}
Expand Down Expand Up @@ -835,7 +840,7 @@ async function medsam_estimation(normalizedStart,normalizedEnd) {
normalized_start: normalizedStart,
normalized_end: normalizedEnd,
seg_class: classesMap.indexOf(BrushSelect.value),
inp_idx: dcm_idx_
dcm_hash: dcm_hash,
};
const box_response = await fetch(
'/medsam_estimation/',
Expand Down Expand Up @@ -992,13 +997,13 @@ async function submit_classes(){
BrushSizeSlider.disabled = false;
Undo.disabled = false;
Redo.disabled = false;
LoadDICOM.disabled = false;
LoadMask.disabled = false;
ExportMasks.disabled = false;
Add.disabled = true;
Remove.disabled = true;
ClassText.disabled = true;
SubmitClasses.disabled = true;
DisplayRadio.disabled=false;
DisplayRadio.disabled = false;
if (classesMap.length !== predefinedClassesMap.length && predefinedClassesMap.length !== 1)
{
var optionModal = new bootstrap.Modal(document.getElementById('optionModal'), {
Expand Down Expand Up @@ -1034,6 +1039,16 @@ async function submit_classes(){
});
get_mask_from_file();
}
if (classesMap.length === 1)
{
ToggleEdit.disabled = true;
ToggleMask.disabled = true;
BrushSizeSlider.disabled = true;
Undo.disabled = true;
Redo.disabled = true;
LoadMask.disabled = true;
ExportMasks.disabled = true;
}
showNotification("success", "Submitted classes", 1500);
}

Expand Down Expand Up @@ -1148,7 +1163,7 @@ function resetGUIElements() {
BrushSizeSlider.disabled = true;
Undo.disabled = true;
Redo.disabled = true;
LoadDICOM.disabled = true;
LoadMask.disabled = true;
ExportMasks.disabled = true;
Add.disabled = true;
Remove.disabled = true;
Expand Down
2 changes: 1 addition & 1 deletion python/templates/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -298,7 +298,7 @@
<i class="bi bi-arrow-90deg-right"></i>
</button>
<button type="button"
id="ResetDICOM"
id="ResetMask"
class="btn btn-primary"
onclick="get_mask_from_file(dcm_idx_)"
title="Discard changes"
Expand Down

0 comments on commit 1423470

Please sign in to comment.