Commit b0c8903
minor bug fixes and improvements
soumickmj committed Feb 14, 2021
1 parent 2a278c9 commit b0c8903
Showing 7 changed files with 154 additions and 119 deletions.
2 changes: 1 addition & 1 deletion Models/unet3d.py
@@ -74,7 +74,7 @@ class U_Net(nn.Module):
     def __init__(self, in_ch=1, out_ch=1):
         super(U_Net, self).__init__()
 
-        n1 = 64 #TODO: original paper starts with 64
+        n1 = 64 #TODO: make params
         filters = [n1, n1 * 2, n1 * 4, n1 * 8, n1 * 16] # 64,128,256,512,1024
 
         self.Maxpool1 = nn.MaxPool3d(kernel_size=2, stride=2)
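
Aside: the reworded TODO points at the hard-coded base filter count. A minimal sketch of the refactor it suggests, purely as an assumption — this commit only rewords the comment, and the n1 keyword argument below is hypothetical (the same comment edit is applied to the three U_Net_DeepSup variants in the next file):

    import torch.nn as nn

    class U_Net(nn.Module):
        def __init__(self, in_ch=1, out_ch=1, n1=64):  # hypothetical: n1 exposed as a constructor param
            super(U_Net, self).__init__()
            filters = [n1, n1 * 2, n1 * 4, n1 * 8, n1 * 16]  # 64,128,256,512,1024 for the default n1=64
            self.Maxpool1 = nn.MaxPool3d(kernel_size=2, stride=2)
            # ... remaining encoder/decoder blocks unchanged ...
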
6 changes: 3 additions & 3 deletions Models/unet3d_DeepSup.py
@@ -104,7 +104,7 @@ class U_Net_DeepSup(nn.Module):
     def __init__(self, in_ch=1, out_ch=1):
         super(U_Net_DeepSup, self).__init__()
 
-        n1 = 64 #TODO: original paper starts with 64
+        n1 = 64 #TODO: make params
         filters = [n1, n1 * 2, n1 * 4, n1 * 8, n1 * 16] # 64,128,256,512,1024
 
         self.Maxpool1 = nn.MaxPool3d(kernel_size=2, stride=2)
@@ -220,7 +220,7 @@ class U_Net_DeepSup_level4(nn.Module):
     def __init__(self, in_ch=1, out_ch=1):
         super(U_Net_DeepSup_level4, self).__init__()
 
-        n1 = 64 #TODO: original paper starts with 64
+        n1 = 64 #TODO: make params
         filters = [n1, n1 * 2, n1 * 4, n1 * 8, n1 * 16] # 64,128,256,512,1024
 
         self.Maxpool1 = nn.MaxPool3d(kernel_size=2, stride=2)
@@ -338,7 +338,7 @@ class U_Net_DeepSup_level4_wta(nn.Module):
     def __init__(self, in_ch=1, out_ch=1):
         super(U_Net_DeepSup_level4_wta, self).__init__()
 
-        n1 = 64 # TODO: original paper starts with 64
+        n1 = 64 # TODO: make params
         filters = [n1, n1 * 2, n1 * 4, n1 * 8, n1 * 16] # 64,128,256,512,1024
 
         self.Maxpool1 = nn.MaxPool3d(kernel_size=2, stride=2)
6 changes: 3 additions & 3 deletions Utils/elastic_transform.py
@@ -48,15 +48,15 @@ def warp_image(image, displacement, multi=False):
     image_size = image.size() #[B, D, H, W]
     batch_size = image_size[0]
     if multi:
-        image_size = image_size[1:]#[D, H, W]
+        image_size = image_size[2:]#[D, H, W]
 
     grid = compute_grid(image_size, dtype=image.dtype, device=image.device)
     grid = displacement + grid
     grid = torch.cat([grid] * batch_size, dim=0) # batch number of times
 
     # warp image
     if multi:
-        warped_image = F.grid_sample(image.unsqueeze(1), grid) #[B, C, D, H, W], unsqueeze to give channel dimension
+        warped_image = F.grid_sample(image, grid) #[B, C, D, H, W]
     else:
         warped_image = F.grid_sample(image.unsqueeze(0).unsqueeze(0), grid) #[B, C, D, H, W], unsqueeze to give batch and channel dimension
 
@@ -206,7 +206,7 @@ def parse_max_displacement(
     Images: shape of [N,D,H,W] or [N,H,W]
     """
     def forward(self, images):
-        bspline_transform = ParameterizedBsplineTransformation(images[0].size(),
+        bspline_transform = ParameterizedBsplineTransformation(images.size()[2:], #ignore batch and channel dim
                                                                sigma=self.num_control_points,
                                                                rnd_grid_params=self.bspline_params,
                                                                diffeomorphic=True,
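
Both hunks fix the same shape assumption: in the multi-channel path the input tensor already carries batch and channel dimensions, [B, C, D, H, W], so the spatial extent is size()[2:] and F.grid_sample consumes the tensor directly, with no extra unsqueeze. A small shape check illustrating the fixed path (tensor sizes here are illustrative, not from the repository):

    import torch
    import torch.nn.functional as F

    image = torch.rand(2, 3, 8, 16, 16)  # [B, C, D, H, W]
    spatial = image.size()[2:]           # torch.Size([8, 16, 16]) -> [D, H, W]
    # For 5-D input, grid_sample expects a sampling grid of shape [B, D, H, W, 3]
    grid = torch.zeros(2, *spatial, 3)   # placeholder; the real grid is compute_grid(...) + displacement
    warped = F.grid_sample(image, grid)  # output keeps [B, C, D, H, W]
    print(warped.shape)                  # torch.Size([2, 3, 8, 16, 16])
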
5 changes: 3 additions & 2 deletions Utils/model_manager.py
@@ -19,12 +19,13 @@
 MODEL_UNET = 1
 MODEL_UNET_DEEPSUP = 2
 MODEL_ATTENTION_UNET = 3
+MODEL_PROBABILISTIC_UNET = 4
 
-def getModel(model_no):
+def getModel(model_no): #Send model params from outside
     defaultModel = U_Net() #Default
     model_list = {
         1: U_Net(),
-        2: U_Net_DeepSup(),
+        2: U_Net_DeepSup(),
         3: AttU_Net()
     }
     return model_list.get(model_no, defaultModel)
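
MODEL_PROBABILISTIC_UNET = 4 is declared but not yet wired into the lookup table, so requesting it currently falls back to the default U_Net. Illustrative calls, assuming this module's existing imports:

    from Utils.model_manager import getModel, MODEL_UNET_DEEPSUP, MODEL_PROBABILISTIC_UNET

    model = getModel(MODEL_UNET_DEEPSUP)           # returns a U_Net_DeepSup instance
    fallback = getModel(MODEL_PROBABILISTIC_UNET)  # no entry for 4 yet, falls back to U_Net

A design note: the dictionary instantiates every model on each call; mapping the numbers to the classes themselves (e.g. 2: U_Net_DeepSup) and instantiating only the selected one would avoid that cost, which the new '#Send model params from outside' comment may anticipate.
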
20 changes: 7 additions & 13 deletions Utils/vessel_utils.py
@@ -31,7 +31,7 @@
 __email__ = "soumick.chatterjee@ovgu.de"
 __status__ = "Production"
 
-def write_summary(writer, logger, index, original, reconstructed, focalTverskyLoss, diceLoss, diceScore, iou):
+def write_summary(writer, logger, index, original=None, reconstructed=None, focalTverskyLoss=0, diceLoss=0, diceScore=0, iou=0):
     """
     Method to write summary to the tensorboard.
     index: global_index for the visualisation
@@ -44,21 +44,18 @@ def write_summary(writer, logger, index, original, reconstructed, focalTverskyLoss, diceLoss, diceScore, iou):
     writer.add_scalar('DiceScore', diceScore, index)
     writer.add_scalar('IOU', iou, index)
 
-    writer.add_image('original', original.cpu().data.numpy()[None,:],index)
-    writer.add_image('reconstructed', reconstructed.cpu().data.numpy()[None,:], index)
-    writer.add_image('diff', np.moveaxis(create_diff_mask(reconstructed,original,logger), -1, 0), index) #create_diff_mask is of the format HXWXC, but CXHXW is needed
+    if original is not None and reconstructed is not None:
+        writer.add_image('original', original.cpu().data.numpy()[None,:],index)
+        writer.add_image('reconstructed', reconstructed.cpu().data.numpy()[None,:], index)
+        writer.add_image('diff', np.moveaxis(create_diff_mask(reconstructed,original,logger), -1, 0), index) #create_diff_mask is of the format HXWXC, but CXHXW is needed
 
-def save_model(CHECKPOINT_PATH, state, best_metric = False,filename='checkpoint'):
+def save_model(CHECKPOINT_PATH, state, filename='checkpoint'):
     """
     Method to save model
     """
     print('Saving model...')
     if not os.path.exists(CHECKPOINT_PATH):
         os.mkdir(CHECKPOINT_PATH)
-    # if best_metric: #TODO check if its needed
-    #     if not os.path.exists(CHECKPOINT_PATH + 'best_metric/'):
-    #         CHECKPOINT_PATH = CHECKPOINT_PATH + 'best_metric/'
-    #         os.mkdir(CHECKPOINT_PATH)
     torch.save(state, CHECKPOINT_PATH + filename + str(state['epoch_type']) + '.pth')


@@ -82,10 +79,7 @@ def load_model_with_amp(model, CHECKPOINT_PATH, batch_index='best', learning_rat
     """
     print('Loading model...')
     model.cuda()
-    try: #TODO dirty fix for now
-        checkpoint = torch.load(CHECKPOINT_PATH + filename + str(batch_index) + '.pth')
-    except:
-        checkpoint = torch.load(CHECKPOINT_PATH + filename + str(batch_index) + '50.pth')
+    checkpoint = torch.load(CHECKPOINT_PATH + filename + str(batch_index) + '.pth')
     optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
     model, optimizer = amp.initialize(model, optimizer, opt_level="O1")
     model.load_state_dict(checkpoint['state_dict'])
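
With the new default arguments, write_summary can log scalar metrics alone, and the image logging runs only when both tensors are supplied. A hedged usage sketch (step, ftl, dl, ds, target, and output are stand-in names, not identifiers from this repository):

    # Scalars only: original/reconstructed default to None, so the add_image block is skipped
    write_summary(writer, logger, index=step, focalTverskyLoss=ftl, diceLoss=dl, diceScore=ds, iou=0.8)

    # Full validation call, with image tensors for TensorBoard
    write_summary(writer, logger, index=step, original=target, reconstructed=output,
                  focalTverskyLoss=ftl, diceLoss=dl, diceScore=ds, iou=0.8)

The loader change is stricter rather than looser: dropping the try/except removes the silent retry with a '50.pth' suffix, so a missing checkpoint now raises immediately instead of masking a path error.
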
25 changes: 19 additions & 6 deletions main_executor.py
@@ -5,6 +5,9 @@
 
 import argparse
 
+import numpy as np
+import random
+
 from Utils.logger import Logger
 from Utils.model_manager import getModel
 from Utils.vessel_utils import load_model_with_amp, load_model
@@ -24,6 +27,12 @@
 __email__ = "soumick.chatterjee@ovgu.de"
 __status__ = "Production"
 
+torch.backends.cudnn.deterministic = True
+torch.backends.cudnn.benchmark = False
+torch.manual_seed(2020)
+np.random.seed(2020)
+random.seed(2020)
+
 if __name__ == '__main__':
 
     parser = argparse.ArgumentParser()
@@ -48,7 +57,7 @@
                         default=True,
                         help="To train the model")
     parser.add_argument('-test',
-                        default=True,
+                        default=False,
                         help="To test the model")
     parser.add_argument('-predict',
                         default=False,
writer_training = SummaryWriter(TENSORBOARD_PATH_TRAINING)
writer_validating = SummaryWriter(TENSORBOARD_PATH_VALIDATION)

pipeline = Pipeline(model=model, optimizer=optimizer, logger=logger, with_apex=args.apex, num_epochs=args.num_epochs,
dir_path=DATASET_FOLDER, checkpoint_path=CHECKPOINT_PATH, deform=args.deform,
writer_training=writer_training, writer_validating=writer_validating,
stride_depth=args.stride_depth, stride_length=args.stride_length, stride_width=args.stride_width,
predict_only=(not args.train) and (not args.test))
# pipeline = Pipeline(model=model, optimizer=optimizer, logger=logger, with_apex=args.apex, num_epochs=args.num_epochs,
# dir_path=DATASET_FOLDER, checkpoint_path=CHECKPOINT_PATH, deform=args.deform,
# writer_training=writer_training, writer_validating=writer_validating,
# stride_depth=args.stride_depth, stride_length=args.stride_length, stride_width=args.stride_width,
# predict_only=(not args.train) and (not args.test))

pipeline = Pipeline(cmd_args=args, model=model, optimizer=optimizer, logger=logger,
dir_path=DATASET_FOLDER, checkpoint_path=CHECKPOINT_PATH,
writer_training=writer_training, writer_validating=writer_validating)

try:

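
Two behavioural changes land in this file. First, runs are now reproducible: random, NumPy, and PyTorch are all seeded with 2020 and cuDNN is forced into deterministic, non-benchmarking mode (on multi-GPU setups a companion torch.cuda.manual_seed_all(2020) is common, though recent PyTorch versions already seed CUDA through torch.manual_seed). Second, -test now defaults to False and the Pipeline receives the parsed argparse namespace as cmd_args instead of a long list of unpacked flags, with the old call kept as a comment. The Pipeline internals are not part of this diff; the sketch below is an assumption based on the old keyword arguments:

    class Pipeline:
        def __init__(self, cmd_args, model, optimizer, logger,
                     dir_path, checkpoint_path, writer_training, writer_validating):
            # Hypothetical body: flags formerly passed one by one are read off the namespace
            self.with_apex = cmd_args.apex
            self.num_epochs = cmd_args.num_epochs
            self.deform = cmd_args.deform
            self.stride_depth = cmd_args.stride_depth
            self.predict_only = (not cmd_args.train) and (not cmd_args.test)
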