Skip to content

Commit

Permalink
support gaussian windows
Browse files Browse the repository at this point in the history
  • Loading branch information
zhaoziheng committed Jan 9, 2024
1 parent 9653d40 commit 8360c10
Showing 1 changed file with 15 additions and 7 deletions.
22 changes: 15 additions & 7 deletions inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from pathlib import Path
import nibabel as nib
from tqdm import tqdm
# from scipy.ndimage import gaussian_filter
from scipy.ndimage import gaussian_filter
from einops import reduce, rearrange, repeat

from dataset.inference_dataset import Inference_Dataset, inference_collate_fn
Expand All @@ -34,7 +34,7 @@ def compute_gaussian(tile_size, sigma_scale: float = 1. / 8, value_scaling_facto

return gaussian_importance_map

def inference(model, tokenizer, text_encoder, device, testloader):
def inference(model, tokenizer, text_encoder, device, testloader, gaussian_filter):
# inference
model.eval()
text_encoder.eval()
Expand All @@ -48,7 +48,7 @@ def inference(model, tokenizer, text_encoder, device, testloader):
testloader = tqdm(testloader, disable=False)

# gaussian kernel to accumulate predcition
gaussian = np.ones((288, 288, 96)) # compute_gaussian((288, 288, 96))
windows = compute_gaussian((288, 288, 96)) if gaussian_filter else np.ones((288, 288, 96))

for batch in testloader: # one batch for each sample
# data loading
Expand Down Expand Up @@ -86,9 +86,9 @@ def inference(model, tokenizer, text_encoder, device, testloader):
for b in range(len(y1y2_x1x2_z1z2_ls)):
y1, y2, x1, x2, z1, z2 = y1y2_x1x2_z1z2_ls[b]

# gaussian accumulation
prediction[n1:n2, y1:y2, x1:x2, z1:z2] += prediction_patch[b, :n2-n1, :y2-y1, :x2-x1, :z2-z1] * gaussian[:y2-y1, :x2-x1, :z2-z1]
accumulation[n1:n2, y1:y2, x1:x2, z1:z2] += gaussian[:y2-y1, :x2-x1, :z2-z1]
# accumulation
prediction[n1:n2, y1:y2, x1:x2, z1:z2] += prediction_patch[b, :n2-n1, :y2-y1, :x2-x1, :z2-z1] * windows[:y2-y1, :x2-x1, :z2-z1]
accumulation[n1:n2, y1:y2, x1:x2, z1:z2] += windows[:y2-y1, :x2-x1, :z2-z1]

# avg
prediction = prediction / accumulation
Expand Down Expand Up @@ -147,9 +147,12 @@ def main(args):
tokenizer = MyTokenizer(args.tokenizer_path)

# choose how to evaluate the checkpoint
inference(model, tokenizer, text_encoder, device, testloader)
inference(model, tokenizer, text_encoder, device, testloader, args.gaussian_filter)

if __name__ == '__main__':
def str2bool(v):
return v.lower() in ('true')

parser = argparse.ArgumentParser()
parser.add_argument(
"--checkpoint",
Expand Down Expand Up @@ -200,6 +203,11 @@ def main(args):
type=str,
default='UNET',
)
parser.add_argument(
"--gaussian_filter",
type=str2bool,
default='False',
)
args = parser.parse_args()

main(args)

0 comments on commit 8360c10

Please sign in to comment.