Skip to content

Commit

Permalink
🐛 Fix Errors in the slidegraph Example Notebook (#608)
Browse files Browse the repository at this point in the history
Fixes some issues with the slidegraph notebooks.

1. Updates due to changes in STRtree in recent shapely versions
2. In the 'cell-composition' mode, add the 'filter_coordinates' step so that the mask is considered when generating graph nodes. Also made a small tweak to the mask filter so that the mask doesn't have to be single-channel.
3. Fix the plot resolution being wrong when not using a pre-generated model.
4. Fixes for a couple of issues related to datatypes, which may have crept in at some point due to numpy or torch version changes.

I have also added a note to explain that the last few cells of the inference notebook are for composition features only, as pretrained model weights are available only for that mode.
---------

Co-authored-by: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com>
  • Loading branch information
measty and shaneahmed committed Jun 23, 2023
1 parent 7369d8d commit e4deac4
Show file tree
Hide file tree
Showing 5 changed files with 176 additions and 125 deletions.
33 changes: 22 additions & 11 deletions examples/full-pipelines/slide-graph.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -638,6 +638,7 @@
"\n",
"def get_cell_compositions(\n",
" wsi_path: str,\n",
" mask_path: str,\n",
" inst_pred_path: str,\n",
" save_dir: str,\n",
" num_types: int = 6,\n",
Expand All @@ -662,8 +663,6 @@
" inst_boxes = np.array(inst_boxes)\n",
"\n",
" geometries = [shapely_box(*bounds) for bounds in inst_boxes]\n",
" # An auxiliary dictionary to actually query the index within the source list\n",
" index_by_id = {id(geo): idx for idx, geo in enumerate(geometries)}\n",
" spatial_indexer = STRtree(geometries)\n",
"\n",
" # * Generate patch coordinates (in xy format)\n",
Expand All @@ -676,21 +675,30 @@
" stride_shape=stride_shape,\n",
" )\n",
"\n",
" # filter out coords which dont lie in mask\n",
" selected_coord_indices = PatchExtractor.filter_coordinates(\n",
" WSIReader.open(mask_path),\n",
" patch_inputs,\n",
" wsi_shape=wsi_shape,\n",
" min_mask_ratio=0.5,\n",
" )\n",
" patch_inputs = patch_inputs[selected_coord_indices]\n",
"\n",
" bounds_compositions = []\n",
" for bounds in patch_inputs:\n",
" bounds_ = shapely_box(*bounds)\n",
" indices = [\n",
" index_by_id[id(geo)]\n",
" geo\n",
" for geo in spatial_indexer.query(bounds_)\n",
" if bounds_.contains(geo)\n",
" if bounds_.contains(geometries[geo])\n",
" ]\n",
" insts = [inst_pred[v][\"type\"] for v in indices]\n",
" uids, freqs = np.unique(insts, return_counts=True)\n",
" # A bound may not contain all types, hence, to sync\n",
" # the array and placement across all types, we create\n",
" # a holder then fill the count within.\n",
" holder = np.zeros(num_types, dtype=np.int16)\n",
" holder[uids] = freqs\n",
" holder[uids.astype(int)] = freqs\n",
" bounds_compositions.append(holder)\n",
" bounds_compositions = np.array(bounds_compositions)\n",
"\n",
Expand All @@ -706,8 +714,11 @@
" inst_segmentor = NucleusInstanceSegmentor(\n",
" pretrained_model=\"hovernet_fast-pannuke\",\n",
" batch_size=16,\n",
" num_postproc_workers=2,\n",
" num_postproc_workers=4,\n",
" num_loader_workers=4,\n",
" )\n",
" # bigger tile shape for postprocessing performance\n",
" inst_segmentor.ioconfig.tile_shape = (4000, 4000)\n",
" # Injecting customized preprocessing functions,\n",
" # check the document or sample codes below for API\n",
" inst_segmentor.model.preproc_func = preproc_func\n",
Expand Down Expand Up @@ -735,7 +746,7 @@
"\n",
" # TODO: parallelize this later if possible\n",
" for idx, path in enumerate(output_paths):\n",
" get_cell_compositions(wsi_paths[idx], path, save_dir)\n",
" get_cell_compositions(wsi_paths[idx], msk_paths[idx], path, save_dir)\n",
" return output_paths"
]
},
Expand Down Expand Up @@ -1035,7 +1046,7 @@
"outputs": [],
"source": [
"NODE_SIZE = 24\n",
"NODE_RESOLUTION = dict(resolution=0.5, units=\"mpp\")\n",
"NODE_RESOLUTION = dict(resolution=0.25, units=\"mpp\")\n",
"PLOT_RESOLUTION = dict(resolution=4.0, units=\"mpp\")"
]
},
Expand Down Expand Up @@ -1077,7 +1088,7 @@
"plot_resolution = reader.slide_dimensions(**PLOT_RESOLUTION)\n",
"fx = np.array(node_resolution) / np.array(plot_resolution)\n",
"\n",
"node_coordinates = np.array(graph.coords) / fx\n",
"node_coordinates = np.array(graph.coordinates) / fx\n",
"edges = graph.edge_index.T\n",
"\n",
"thumb = reader.slide_thumbnail(**PLOT_RESOLUTION)\n",
Expand Down Expand Up @@ -2458,7 +2469,7 @@
"\n",
"NODE_SIZE = 25\n",
"NUM_NODE_FEATURES = 4\n",
"NODE_RESOLUTION = dict(resolution=0.5, units=\"mpp\")\n",
"NODE_RESOLUTION = dict(resolution=0.25, units=\"mpp\")\n",
"PLOT_RESOLUTION = dict(resolution=4.0, units=\"mpp\")\n",
"\n",
"node_scaler = joblib.load(SCALER_PATH)\n",
Expand Down Expand Up @@ -2503,7 +2514,7 @@
"cmap = plt.get_cmap(\"inferno\")\n",
"graph = graph.to(\"cpu\")\n",
"\n",
"node_coordinates = np.array(graph.coords) / fx\n",
"node_coordinates = np.array(graph.coordinates) / fx\n",
"node_colors = (cmap(np.squeeze(node_activations))[..., :3] * 255).astype(np.uint8)\n",
"edges = graph.edge_index.T\n",
"\n",
Expand Down
262 changes: 151 additions & 111 deletions examples/inference-pipelines/slide-graph.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion requirements/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# torch installation
--extra-index-url https://download.pytorch.org/whl/cu117; sys_platform != "darwin"
--extra-index-url https://download.pytorch.org/whl/cu118; sys_platform != "darwin"
albumentations>=1.3.0
Click>=8.1.3
defusedxml>=0.7.1
Expand Down
2 changes: 1 addition & 1 deletion tiatoolbox/tools/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,7 +395,7 @@ def build(

return {
"x": feature_centroids,
"edge_index": edge_index,
"edge_index": edge_index.astype(np.int64),
"coordinates": point_centroids,
}

Expand Down
2 changes: 1 addition & 1 deletion tiatoolbox/tools/patchextraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,7 @@ def filter_coordinates(
tissue_mask = mask_reader.img

# Scaling the coordinates_list to the `tissue_mask` array resolution
scale_factors = np.array(tissue_mask.shape[::-1]) / np.array(wsi_shape)
scale_factors = np.array(tissue_mask.shape[1::-1]) / np.array(wsi_shape)
scaled_coords = coordinates_list.copy().astype(np.float32)
scaled_coords[:, [0, 2]] *= scale_factors[0]
scaled_coords[:, [0, 2]] = np.clip(
Expand Down

0 comments on commit e4deac4

Please sign in to comment.