From 8360c1080f7235d9686dd1abb72b18a8f5222ad6 Mon Sep 17 00:00:00 2001
From: zhaoziheng <565295081@qq.com>
Date: Wed, 10 Jan 2024 02:21:02 +0800
Subject: [PATCH] support gaussian windows

---
 inference.py | 22 +++++++++++++++-------
 1 file changed, 15 insertions(+), 7 deletions(-)

diff --git a/inference.py b/inference.py
index 1c56bf2..688cd4a 100644
--- a/inference.py
+++ b/inference.py
@@ -11,7 +11,7 @@ from pathlib import Path
 import nibabel as nib
 from tqdm import tqdm
 
-# from scipy.ndimage import gaussian_filter
+from scipy.ndimage import gaussian_filter
 from einops import reduce, rearrange, repeat
 
 from dataset.inference_dataset import Inference_Dataset, inference_collate_fn
@@ -34,7 +34,7 @@ def compute_gaussian(tile_size, sigma_scale: float = 1. / 8, value_scaling_facto
 
     return gaussian_importance_map
 
-def inference(model, tokenizer, text_encoder, device, testloader):
+def inference(model, tokenizer, text_encoder, device, testloader, gaussian_filter):
     # inference
     model.eval()
     text_encoder.eval()
@@ -48,7 +48,7 @@ def inference(model, tokenizer, text_encoder, device, testloader):
     testloader = tqdm(testloader, disable=False)
 
     # gaussian kernel to accumulate predcition
-    gaussian = np.ones((288, 288, 96)) # compute_gaussian((288, 288, 96))
+    windows = compute_gaussian((288, 288, 96)) if gaussian_filter else np.ones((288, 288, 96))
 
     for batch in testloader:    # one batch for each sample
         # data loading
@@ -86,9 +86,9 @@ def inference(model, tokenizer, text_encoder, device, testloader):
 
                 for b in range(len(y1y2_x1x2_z1z2_ls)):
                     y1, y2, x1, x2, z1, z2 = y1y2_x1x2_z1z2_ls[b]
-                    # gaussian accumulation
-                    prediction[n1:n2, y1:y2, x1:x2, z1:z2] += prediction_patch[b, :n2-n1, :y2-y1, :x2-x1, :z2-z1] * gaussian[:y2-y1, :x2-x1, :z2-z1]
-                    accumulation[n1:n2, y1:y2, x1:x2, z1:z2] += gaussian[:y2-y1, :x2-x1, :z2-z1]
+                    # accumulation
+                    prediction[n1:n2, y1:y2, x1:x2, z1:z2] += prediction_patch[b, :n2-n1, :y2-y1, :x2-x1, :z2-z1] * windows[:y2-y1, :x2-x1, :z2-z1]
+                    accumulation[n1:n2, y1:y2, x1:x2, z1:z2] += windows[:y2-y1, :x2-x1, :z2-z1]
 
             # avg
             prediction = prediction / accumulation
@@ -147,9 +147,12 @@ def main(args):
     tokenizer = MyTokenizer(args.tokenizer_path)
 
     # choose how to evaluate the checkpoint
-    inference(model, tokenizer, text_encoder, device, testloader)
+    inference(model, tokenizer, text_encoder, device, testloader, args.gaussian_filter)
 
 if __name__ == '__main__':
+    def str2bool(v):
+        return v.lower() in ('true')
+
     parser = argparse.ArgumentParser()
     parser.add_argument(
         "--checkpoint",
@@ -200,6 +203,11 @@ def main(args):
         type=str,
         default='UNET',
     )
+    parser.add_argument(
+        "--gaussian_filter",
+        type=str2bool,
+        default='False',
+    )
 
     args = parser.parse_args()
     main(args)
\ No newline at end of file
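
Note on the Gaussian window: the body of compute_gaussian is not shown in the hunks above, but the
uncommented scipy.ndimage.gaussian_filter import and the `windows = compute_gaussian((288, 288, 96))`
call point to an nnU-Net-style importance map: a unit impulse at the patch centre is blurred into a
Gaussian, so voxels near the patch border get less weight when overlapping sliding-window patches are
averaged. The following is only a sketch of such a map, assuming the sigma_scale and
value_scaling_factor defaults hinted at in the truncated signature; it is not the repository's actual
implementation.

    import numpy as np
    from scipy.ndimage import gaussian_filter

    def compute_gaussian(tile_size, sigma_scale: float = 1. / 8, value_scaling_factor: float = 1.0):
        # Sketch only: place a unit impulse at the patch centre ...
        tmp = np.zeros(tile_size)
        center_coords = [i // 2 for i in tile_size]
        sigmas = [i * sigma_scale for i in tile_size]
        tmp[tuple(center_coords)] = 1
        # ... and blur it into a smooth importance map that decays towards the patch borders
        gaussian_importance_map = gaussian_filter(tmp, sigmas, mode='constant', cval=0)
        # rescale so the peak weight equals value_scaling_factor
        gaussian_importance_map = gaussian_importance_map / np.max(gaussian_importance_map) * value_scaling_factor
        # keep every weight strictly positive so `prediction / accumulation` never divides by zero
        gaussian_importance_map[gaussian_importance_map == 0] = np.min(
            gaussian_importance_map[gaussian_importance_map != 0])
        return gaussian_importance_map

With the default all-ones window every overlapping patch contributes equally (plain averaging), which
is why the accumulation and division code is identical in both modes; only the per-voxel weights change.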
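The new flag is parsed with str2bool, so --gaussian_filter True (or true) enables the Gaussian window
and the string default 'False' keeps the previous all-ones behaviour. A hypothetical invocation is
shown below; the checkpoint path is a placeholder and the script's other required arguments are
omitted here.

    python inference.py \
        --checkpoint /path/to/checkpoint.pth \
        --gaussian_filter True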