patch inference (#87)
* testing

* add functions/change torch to numpy

* same as last time

* finished patch_inference

* added parameter '-p'

* remove test files

* added patch_inference_3d_lite

added a version of patch inference that uses less RAM

---------

Co-authored-by: JiaShow <p269511@gmail.com>
Norman960122 and pengushow authored Aug 28, 2024
1 parent 41969d8 commit 7dbbd15
Showing 3 changed files with 128 additions and 11 deletions.
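
For reviewers, a minimal sketch of how the new patch mode might be exercised once this lands; the `tigerbx` console command, the `tigerbx.bx` import path, and the file names are illustrative assumptions rather than part of this commit:

# Hypothetical usage sketch (paths and entry points are placeholders).
# CLI: '-p' toggles patch-based inference; brain extraction still runs by default,
# because run_d['patch'] is added to the list that forces run_d['bet'] = True.
#   tigerbx -p T1w.nii.gz output_dir
# Python API: single-letter flags in the argstring map onto args.*, so 'p' sets args.patch.
from tigerbx import bx
result = bx.run('p', 'T1w.nii.gz', 'output_dir')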
13 changes: 7 additions & 6 deletions tigerbx/bx.py
@@ -20,7 +20,7 @@
elif __file__:
application_path = os.path.dirname(os.path.abspath(__file__))

def produce_mask(model, f, GPU=False, QC=False, brainmask_nib=None, tbet111=None):
def produce_mask(model, f, GPU=False, QC=False, brainmask_nib=None, tbet111=None, patch=False):
if not isinstance(model, list):
model = [model]
# for multi-model ensemble
@@ -38,7 +38,7 @@ def produce_mask(model, f, GPU=False, QC=False, brainmask_nib=None, tbet111=None


mask_nib_resp, prob_resp = lib_bx.run(
model_ff_list, input_nib_resp, GPU=GPU)
model_ff_list, input_nib_resp, GPU=GPU, patch=patch)

mask_nib = resample_to_img(
mask_nib_resp, input_nib, interpolation="nearest")
@@ -128,6 +128,7 @@ def main():
parser.add_argument('-r', '--registration', action='store_true', help='Registering images to template')
parser.add_argument('-T', '--template', type=str, help='The template filename(default is MNI152)')
parser.add_argument('-R', '--rigid', action='store_true', help='Rigid transforms images to template')
parser.add_argument('-p', '--patch', action='store_true', help='patch inference')
parser.add_argument('--model', default=None, type=str, help='Specifying the model name')
parser.add_argument('--clean_onnx', action='store_true', help='Clean onnx models')
parser.add_argument('--encode', action='store_true', help='Encoding a brain volume to its latent')
@@ -142,7 +143,6 @@


def run(argstring, input=None, output=None, model=None, template=None):

from argparse import Namespace
args = Namespace()
if not isinstance(input, list):
@@ -175,6 +175,7 @@ def run(argstring, input=None, output=None, model=None, template=None):
args.affine = 'A' in argstring
args.registration = 'r' in argstring
args.rigid = 'R' in argstring
args.patch = 'p' in argstring
args.template = template
return run_args(args)

@@ -187,7 +188,7 @@ def run_args(args):
run_d['wmh'], run_d['bam'], run_d['tumor'], run_d['cgw'],
run_d['syn'], run_d['affine'], run_d['registration'],
run_d['rigid'], run_d['template'], run_d['encode'],
run_d['decode']]:
run_d['decode'], run_d['patch']]:
run_d['bet'] = True
# Producing extracted brain by default

@@ -229,7 +230,7 @@ def run_args(args):
omodel['rigid'] = 'mprage_rigid_v001_train.onnx'
omodel['encode'] = 'mprage_encode_v1.onnx'
omodel['decode'] = 'mprage_decode_v1.onnx'

if run_d['encode'] or run_d['decode']:
print('#Autoencoding weights converted from')
print('#Pinaya, Walter HL, et al. Brain imaging generation with latent diffusion models.')
@@ -349,7 +350,7 @@ def run_args(args):
for seg_str in ['aseg', 'dgm', 'dkt', 'wmp', 'wmh', 'tumor', 'syn']:
if run_d[seg_str]:
result_nib = produce_mask(omodel[seg_str], f, GPU=args.gpu,
brainmask_nib=tbetmask_nib, tbet111=tbet_seg)
brainmask_nib=tbetmask_nib, tbet111=tbet_seg, patch=run_d['patch'])
fn = save_nib(result_nib, ftemplate, seg_str)
result_dict[seg_str] = result_nib
result_filedict[seg_str] = fn
7 changes: 5 additions & 2 deletions tigerbx/lib_bx.py
@@ -109,7 +109,7 @@ def logit_to_prob(logits, seg_mode):
prob = softmax(logits, axis=0)
return prob

def run(model_ff_list, input_nib, GPU):
def run(model_ff_list, input_nib, GPU, patch=False):

if not isinstance(model_ff_list, list):
model_ff_list = [model_ff_list]
@@ -130,7 +130,10 @@ def run(model_ff_list, input_nib, GPU):
count = 0
for model_ff in model_ff_list:
count += 1
logits = lib_tool.predict(model_ff, image, GPU)[0, ...]
if patch:
logits = lib_tool.predict(model_ff, image, GPU, mode='patch')[0, ...]
else:
logits = lib_tool.predict(model_ff, image, GPU)[0, ...]
prob += logit_to_prob(logits, seg_mode)
prob = prob/count # average the prob

119 changes: 116 additions & 3 deletions tigerbx/lib_tool.py
@@ -14,8 +14,8 @@
from os.path import isfile, join
from tigerbx import lib_bx
from nilearn.image import resample_img


from typing import Union, Tuple, List
from scipy.ndimage import gaussian_filter
warnings.filterwarnings("ignore", category=UserWarning)
nib.Nifti1Header.quaternion_threshold = -100

@@ -273,9 +273,122 @@ def predict(model, data, GPU, mode=None):
if mode == 'decode':
result = session.run(None, {session.get_inputs()[0].name: data.astype(data_type)}, )
return result[0]


    if mode == 'patch':
        # memory-efficient sliding-window inference over 160^3 patches
        logits = patch_inference_3d_lite(session, data.astype(data_type), patch_size=(160,)*3, gaussian=True)
        return logits

    return session.run(None, {session.get_inputs()[0].name: data.astype(data_type)}, )[0]
def patch_inference_3d_lite(session,
                            vol_d: np.ndarray,
                            patch_size: Tuple[int, ...] = (128,)*3,
                            tile_step_size: float = 0.5,
                            gaussian: bool = False):
    # Memory-efficient sliding-window inference: each patch's logits are blended
    # into a single output volume as they are produced, so only one set of patch
    # logits is held in RAM at a time.
    patches, point_list = img_to_patches(vol_d, patch_size, tile_step_size)  # (patch_num, 1, 1, *patch_size)
    gaussian_map = compute_gaussian(patch_size)
    patch_logits_shape = session.run(None, {session.get_inputs()[0].name: patches[0]}, )[0].shape
    prob_tensor = np.zeros((patch_logits_shape[1],) + vol_d.shape[-3:])
    for patch, p in zip(patches, point_list):
        logits = session.run(None, {session.get_inputs()[0].name: patch}, )[0]  # (1, c, *patch_size)
        if gaussian:
            output_patch = logits.squeeze(0) * gaussian_map
        else:
            output_patch = logits.squeeze(0)
        none_zero_mask1 = prob_tensor[:, p[0]:p[0]+patch_size[0], p[1]:p[1]+patch_size[1], p[2]:p[2]+patch_size[2]] != 0
        none_zero_mask2 = output_patch != 0
        none_zero_num = np.clip(none_zero_mask1 + none_zero_mask2, a_min=1, a_max=None)
        prob_tensor[:, p[0]:p[0]+patch_size[0], p[1]:p[1]+patch_size[1], p[2]:p[2]+patch_size[2]] += output_patch
        prob_tensor[:, p[0]:p[0]+patch_size[0], p[1]:p[1]+patch_size[1], p[2]:p[2]+patch_size[2]] /= none_zero_num
    return prob_tensor[np.newaxis, :]
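
A quick, self-contained way to drive the lite path outside predict(); the model file name, the 192-voxel volume, and the CPU provider are illustrative assumptions, and the ONNX model is assumed to accept a float32 tensor shaped (1, 1, 160, 160, 160):

# Hypothetical standalone check of patch_inference_3d_lite (not part of the commit;
# run with this module's helpers in scope).
import numpy as np
import onnxruntime as ort

session = ort.InferenceSession('some_model.onnx', providers=['CPUExecutionProvider'])
vol_d = np.random.rand(1, 1, 192, 192, 192).astype(np.float32)  # (batch, channel, D, H, W)
prob = patch_inference_3d_lite(session, vol_d, patch_size=(160,)*3, gaussian=True)
print(prob.shape)  # expected (1, num_classes, 192, 192, 192)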



def patch_inference_3d(session,
                       vol_d: np.ndarray,
                       patch_size: Tuple[int, ...] = (128,)*3,
                       tile_step_size: float = 0.5,
                       gaussian: bool = False):
    # Straightforward sliding-window inference: keep the logits of every patch
    # in memory, then stitch them back together with patches_to_img.
    patches, point_list = img_to_patches(vol_d, patch_size, tile_step_size)  # (patch_num, 1, 1, *patch_size)
    output_patch_list = []
    for patch in patches:
        logits = session.run(None, {session.get_inputs()[0].name: patch}, )[0]  # (1, c, *patch_size)
        output_patch_list.append(logits.squeeze(0))
    output_patches = np.concatenate([s[np.newaxis, ...] for s in output_patch_list], axis=0)  # (patch_num, c, w, h, d)
    if gaussian:
        gaussian_map = compute_gaussian(patch_size)
        output_patches = output_patches * gaussian_map
    mean_prob = patches_to_img(output_patches, vol_d.shape[-3:], point_list)
    return mean_prob
def compute_steps_for_sliding_window(image_size: Tuple[int, ...],
tile_size: Tuple[int, ...],
tile_step_size: float) -> List[List[int]]:
    assert all(i >= j for i, j in zip(image_size, tile_size)), "image size must be as large or larger than patch_size"
    assert 0 < tile_step_size <= 1, 'step_size must be larger than 0 and smaller or equal to 1'

# our step width is patch_size*step_size at most, but can be narrower. For example if we have image size of
# 110, patch size of 64 and step_size of 0.5, then we want to make 3 steps starting at coordinate 0, 23, 46
target_step_sizes_in_voxels = [i * tile_step_size for i in tile_size]

num_steps = [int(np.ceil((i - k) / j)) + 1 for i, j, k in zip(image_size, target_step_sizes_in_voxels, tile_size)]

steps = []
for dim in range(len(tile_size)):
# the highest step value for this dimension is
max_step_value = image_size[dim] - tile_size[dim]
if num_steps[dim] > 1:
actual_step_size = max_step_value / (num_steps[dim] - 1)
else:
actual_step_size = 99999999999 # does not matter because there is only one step at 0

steps_here = [int(np.round(actual_step_size * i)) for i in range(num_steps[dim])]

steps.append(steps_here)
return steps
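
The worked example in the comment above can be checked directly; a tiny sketch (not part of the commit, assumes this module is importable):

# With image size 110, patch size 64 and step_size 0.5, three steps start at 0, 23, 46.
steps = compute_steps_for_sliding_window((110, 110, 110), (64, 64, 64), 0.5)
print(steps)  # [[0, 23, 46], [0, 23, 46], [0, 23, 46]]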


def compute_gaussian(tile_size: Union[Tuple[int, ...], List[int]],
sigma_scale: float = 1. / 8,
value_scaling_factor: float = 1,
dtype=np.float16) -> np.ndarray:
tmp = np.zeros(tile_size)
center_coords = [i // 2 for i in tile_size]
sigmas = [i * sigma_scale for i in tile_size]
tmp[tuple(center_coords)] = 1
gaussian_importance_map = gaussian_filter(tmp, sigmas, 0, mode='constant', cval=0)

gaussian_importance_map /= (np.max(gaussian_importance_map) / value_scaling_factor)
gaussian_importance_map = gaussian_importance_map.astype(dtype)
# gaussian_importance_map cannot be 0, otherwise we may end up with nans!
mask = gaussian_importance_map == 0
gaussian_importance_map[mask] = np.min(gaussian_importance_map[~mask])
return gaussian_importance_map
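
As a brief illustration of why the zero-replacement matters: the map peaks at the patch centre and stays strictly positive everywhere, so it can weight logits without zeroing out border voxels (a sketch, not part of the commit):

# The importance map is normalised to 1 at the centre and clamped away from zero.
g = compute_gaussian((64, 64, 64))
print(float(g[32, 32, 32]), float(g.max()), bool(g.min() > 0))  # ~1.0, 1.0, True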

def img_to_patches(vol_d: np.ndarray, patch_size: Tuple[int, ...], tile_step_size: float):
steps = compute_steps_for_sliding_window(vol_d.shape[-3:], patch_size, tile_step_size)
slice_list = []
point_list = [[i, j, k] for i in steps[0] for j in steps[1] for k in steps[2]]
for p in point_list:
slice_input = vol_d[:, :, p[0] : p[0]+patch_size[0], p[1] : p[1]+patch_size[1], p[2] : p[2]+patch_size[2]]
slice_list.append(slice_input)
return np.concatenate([s[np.newaxis, ...] for s in slice_list], axis=0), point_list

def patches_to_img(patches: np.ndarray, vol_d_size: Tuple[int, ...], point_list: List[List[int]]):
    '''
    patches shape = (patch_num, channel, w, h, d)
    '''
    patch_size = patches.shape[-3:]
    prob_tensor = np.zeros((patches.shape[1],) + vol_d_size)

    for patch_dim, p in zip(range(patches.shape[0]), point_list):
        # blend overlap: where both the running result and the current patch are
        # non-zero, the sum is divided by 2 (a running pairwise average)
        none_zero_mask1 = prob_tensor[:, p[0]:p[0]+patch_size[0], p[1]:p[1]+patch_size[1], p[2]:p[2]+patch_size[2]] != 0
        none_zero_mask2 = patches[patch_dim, :, ...] != 0
        none_zero_num = np.clip(none_zero_mask1 + none_zero_mask2, a_min=1, a_max=None)
        prob_tensor[:, p[0]:p[0]+patch_size[0], p[1]:p[1]+patch_size[1], p[2]:p[2]+patch_size[2]] += patches[patch_dim, :, ...]
        prob_tensor[:, p[0]:p[0]+patch_size[0], p[1]:p[1]+patch_size[1], p[2]:p[2]+patch_size[2]] /= none_zero_num
    return prob_tensor[np.newaxis, :]
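
A small round-trip sketch shows what the non-zero-count blending does: splitting a volume into overlapping patches and stitching them back with identity "logits" reproduces the input (illustration only, not part of the commit; assumes numpy is imported as np and the helpers above are in scope):

# Round-trip check of img_to_patches / patches_to_img on a random volume.
vol = np.random.rand(1, 1, 96, 96, 96).astype(np.float32)
patches, points = img_to_patches(vol, (64, 64, 64), 0.5)
recon = patches_to_img(patches.squeeze(1), vol.shape[-3:], points)  # drop the per-patch batch axis
print(bool(np.allclose(recon, vol)))  # True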

def clean_onnx():
import glob
