From 304b8491925859692ed55661903a2931eef47c6e Mon Sep 17 00:00:00 2001 From: perukas Date: Thu, 29 Aug 2019 17:44:47 +0300 Subject: [PATCH] init code commit --- .gitignore | 4 + dataset360N.py | 111 ++++++++++++++++++ file_utils.py | 33 ++++++ losses.py | 118 +++++++++++++++++++ metrics.py | 2 + model.py | 221 +++++++++++++++++++++++++++++++++++ nn_utils.py | 100 ++++++++++++++++ settings_files/continue.json | 0 settings_files/test.json | 10 ++ settings_files/train.json | 29 +++++ test.py | 92 +++++++++++++++ train.py | 184 +++++++++++++++++++++++++++++ 12 files changed, 904 insertions(+) create mode 100644 dataset360N.py create mode 100644 file_utils.py create mode 100644 losses.py create mode 100644 metrics.py create mode 100644 model.py create mode 100644 nn_utils.py create mode 100644 settings_files/continue.json create mode 100644 settings_files/test.json create mode 100644 settings_files/train.json create mode 100644 test.py create mode 100644 train.py diff --git a/.gitignore b/.gitignore index 894a44c..7906d5b 100644 --- a/.gitignore +++ b/.gitignore @@ -6,6 +6,10 @@ __pycache__/ # C extensions *.so +# vs-code py cache folder +__pychache__ +.vscode + # Distribution / packaging .Python build/ diff --git a/dataset360N.py b/dataset360N.py new file mode 100644 index 0000000..9587340 --- /dev/null +++ b/dataset360N.py @@ -0,0 +1,111 @@ +import os +import torch +import numpy as np +import cv2 +from torch.utils.data import Dataset +from torchvision import transforms, utils +import torch.nn.functional as F +import PIL.Image as Image +import random + + +class Dataset360N(Dataset): + def __init__(self, filenames_filepath, delimiter, input_shape): + self.length = 0 + self.height = input_shape[0] + self.width = input_shape[1] + self.filenames_filepath = filenames_filepath + self.delim = delimiter + self.data_paths = {} + self.init_data_dict() + self.gather_filepaths() + + def init_data_dict(self): + self.data_paths = { + "rgb": [], + "surface": [] + } + + def gather_filepaths(self): + fd = open(self.filenames_filepath, 'r') + lines = fd.readlines() + for line in lines: + splits = line.split(self.delim) + self.data_paths["rgb"].append(splits[0]) + # TODO: check for filenames files format + self.data_paths["surface"].append(splits[6]) + fd.close() + assert len(self.data_paths["rgb"]) == len(self.data_paths["surface"]) + self.length = len(self.data_paths["rgb"]) + + def load_rgb(self, filepath): + if not os.path.exists(filepath): + print("\tGiven filepath <{}> does not exist".format(filepath)) + return np.zeros((self.height, self.width, 3), dtype = np.float32) + rgb_np = cv2.imread(filepath, cv2.IMREAD_ANYCOLOR) + return rgb_np + + def load_float(self, filepath): + if not os.path.exists(filepath): + print("\tGiven filepath <{}> does not exist".format(filepath)) + return np.zeros((self.height, self.width, 3), dtype = np.float32) + surface_np = cv2.imread(filepath, cv2.IMREAD_UNCHANGED) + # Creates mask for invalid values + surface_np[np.isnan(surface_np)] = 0.0 + mask_np = np.ones_like(surface_np) + mask_np[np.sum(surface_np, 2) == 0.0] = 0.0 + return surface_np, mask_np + + def clean_normal(self, normal): + # check if normals are close to the dominant + # coord system normals + shape = normal.shape + vecs = [ + [1, 0, 0], + [0, 1, 0], + [0, 0 ,1], + [-1, 0, 0], + [0, -1, 0], + [0, 0, -1] + ] + for vec in vecs: + vec_mat = np.asarray(vec, dtype = np.float32) + vec_mat = np.expand_dims(vec_mat, 0) + vec_mat = np.expand_dims(vec_mat, 1) + vec_mat = vec_mat.repeat(shape[0], 0) + vec_mat = vec_mat.repeat(shape[1], 1) + inds = np.isclose(normal, vec_mat, 0.0001, 0.1) + inds = inds[:, :, 0] & inds[:, :, 1] & inds[:, :, 2] + normal[inds, 0] = vec[0] + normal[inds, 1] = vec[1] + normal[inds, 2] = vec[2] + return normal + + def make_tensor(self, np_array): + np_array = np_array.transpose(2, 0, 1) + tensor = torch.from_numpy(np_array) + return torch.as_tensor(tensor, dtype = torch.float32) + + def load_item(self, idx): + item = { } + if (idx >= self.length): + print("Index out of range.") + else: + rgb_np = self.load_rgb(self.data_paths["rgb"][idx]) + surface_np, mask_np = self.load_float(self.data_paths["surface"][idx]) + surface_np = self.clean_normal(surface_np) + rgb = self.make_tensor(rgb_np) + surface = self.make_tensor(surface_np) + surface = F.normalize(surface, p = 2, dim = 1) + mask = self.make_tensor(mask_np) + item['input_rgb'] = rgb + item['target_surface'] = surface + item['mask'] = mask + item['filename'] = os.path.basename(self.data_paths["surface"][idx]) + return item + + def __len__(self): + return self.length + + def __getitem__(self, idx): + return self.load_item(idx) \ No newline at end of file diff --git a/file_utils.py b/file_utils.py new file mode 100644 index 0000000..ad9c716 --- /dev/null +++ b/file_utils.py @@ -0,0 +1,33 @@ +import os +import json +import torch + +''' + Reads configuration file + \param + filepath the absolute path to the configuration file + \return + settings_map dictionary with the configuration settings +''' +def read_configuration_file(filepath): + print("Reading configuration file...") + settings = {} + if os.path.exists(filepath): + with open(filepath, 'r') as fd: + settings = json.load(fd) + assert settings['session'], print("Failed to read configuration file. No session settings.") + # assert settings['session']['optimizer'], print("Failed to read configuration file. No optimizer settings.") + return settings + +def save_state(directory, session_name, model, optimizer, epoch, global_iters): + if os.path.isfile(directory): + directory = os.path.abspath(os.path.dirname(directory)) + model_state_dict = model.state_dict() + optim_state_dict = optimizer.state_dict() + model_filename = session_name + "_model_e_{}_b_{}.chkp".format(epoch, global_iters) + optim_filename = session_name + "_optim_e_{}_b_{}.chkp".format(epoch, global_iters) + model_filepath = os.path.join(directory, model_filename) + optim_filepath = os.path.join(directory, optim_filename) + torch.save(model_state_dict, model_filepath) + torch.save(optim_state_dict, optim_filepath) + diff --git a/losses.py b/losses.py new file mode 100644 index 0000000..ba3a84e --- /dev/null +++ b/losses.py @@ -0,0 +1,118 @@ +import torch +import torch.nn.functional as F +import numpy as np + +# image gradient computations +''' + Image gradient x-direction + \param + input_tensor + \return + input_tensor's x-direction gradients +''' +def grad_x(input_tensor): + input_tensor = F.pad(input_tensor, (0, 1, 0, 0), mode = "replicate") + gx = input_tensor[:, :, :, :-1] - input_tensor[:, :, :, 1:] + return gx + +''' + Image gradient y-direction + \param + input_tensor + \return + input_tensor's y-direction gradients +''' +def grad_y(input_tensor): + input_tensor = F.pad(input_tensor, (0, 0, 0, 1), mode = "replicate") + gy = input_tensor[:, :, :-1, :] - input_tensor[:, :, 1:, :] + return gy + +''' + L2 Loss + \param + input input tensor (model's prediction) + target target tensor (ground truth) + use_mask set True to compute masked loss + mask Binary mask tensor + \return + L2 loss mean between target and input + L2 loss map between target and input +''' +def l2_loss(input, target, use_mask = True, mask = None): + loss = torch.pow(target - input, 2) + if use_mask and mask is not None: + count = torch.sum(mask).item() + masked_loss = loss * mask + return torch.sum(masked_loss) / count, masked_loss + return torch.mean(loss), loss + +''' + Cosine Similarity loss (vector dot product) + \param + input input tensor (model's prediction) + target target tensor (ground truth) + use_mask set True to compute masked loss + mask Binary mask tensor + \return + Cosine similarity loss mean between target and input + Cosine similarity loss map betweem target and input +''' +def cosine_loss(input, target, use_mask = True, mask = None): + loss = 2 - (1 + torch.sum(input * target, dim = 1, keepdim = True)) + if use_mask and mask is not None: + count = torch.sum(mask) + masked_loss = loss * mask + return torch.sum(masked_loss) / count, masked_loss + return torch.mean(loss), loss +''' + Quaternion loss + \param + input input tensor (model's prediction) + target target tensor (ground truth) + use_mask set True to compute masked loss + mask Binary mask tensor + \return + Quaternion loss mean between target and input + Quaternion loss map betweem target and input +''' +def quaternion_loss(input, target, use_mask = True, mask = None): + q_pred = -input + loss_x = target[:, 1, :, :] * q_pred[:, 2, :, :] - target[:, 2, :, :] * q_pred[:, 1, :, :] + loss_y = target[:, 2, :, :] * q_pred[:, 0, :, :] - target[:, 0, :, :] * q_pred[:, 2, :, :] + loss_z = target[:, 0, :, :] * q_pred[:, 1, :, :] - target[:, 1, :, :] * q_pred[:, 0, :, :] + loss_re = -target[:, 0, :, :] * q_pred[:, 0, :, :] - target[:, 1, :, :] * q_pred[:, 1, :, :] - target[:, 2, :, :] * q_pred[:, 2, :, :] + loss_x = loss_x.unsqueeze(1) + loss_y = loss_y.unsqueeze(1) + loss_z = loss_z.unsqueeze(1) + loss_xyz = torch.cat((loss_x, loss_y, loss_z), 1) + + dot = loss_x * loss_x + loss_y * loss_y + loss_z * loss_z + eps = torch.ones_like(dot) * 1e-8 + + vec_diff = torch.sqrt(torch.max(dot, eps)) + real_diff = torch.sign(loss_re) * torch.abs(loss_re) + real_diff = real_diff.unsqueeze(1) + + loss = torch.atan2(vec_diff, real_diff) / (np.pi) + + if mask is not None: + count = torch.sum(mask) + mask = mask[:, 0, :, :].unsqueeze(1) + masked_loss = loss * mask + return torch.sum(masked_loss) / count, masked_loss + return torch.mean(loss) + +''' + Smoothness loss + \param + input input tensor (model's prediction) +''' +def smoothness_loss(input, use_mask = True, mask = None): + grads_x = grad_x(input) + grads_y = grad_y(input) + loss = torch.abs(grads_x) + torch.abs(grads_y) + if mask is not None: + count = torch.sum(mask).item() + masked_loss = mask * loss + return torch.sum(masked_loss) / count, masked_loss + return torch.mean(loss), loss \ No newline at end of file diff --git a/metrics.py b/metrics.py new file mode 100644 index 0000000..d3d4856 --- /dev/null +++ b/metrics.py @@ -0,0 +1,2 @@ +import numpy as np +import torch \ No newline at end of file diff --git a/model.py b/model.py new file mode 100644 index 0000000..9ba0d7c --- /dev/null +++ b/model.py @@ -0,0 +1,221 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchvision + +class VGG16Unet(nn.Module): + def __init__(self, features = None): + super(VGG16Unet, self).__init__() + # Encoder pre-trained features + self.features = features + # upsamplig block + self.up = nn.UpsamplingBilinear2d(scale_factor = 2.0) + # Rest blocks + # Encoder + self.enc_block0 = nn.Sequential( + nn.Conv2d(3, 64, kernel_size = 3, padding = 1, stride = 1), + nn.BatchNorm2d(64), + nn.ReLU(inplace = True), + nn.Conv2d(64, 64, kernel_size = 3, padding = 1, stride = 1), + nn.BatchNorm2d(64), + nn.ReLU(inplace = True) + ) + + self.enc_block1 = nn.Sequential( + nn.MaxPool2d(kernel_size = 2, stride = 2), + nn.Conv2d(64, 128, kernel_size = 3, padding = 1, stride = 1), + nn.BatchNorm2d(128), + nn.ReLU(inplace = True), + nn.Conv2d(128, 128, kernel_size = 3, padding = 1, stride = 1), + nn.BatchNorm2d(128), + nn.ReLU(inplace = True) + ) + + self.enc_block2 = nn.Sequential( + nn.MaxPool2d(kernel_size = 2, stride = 2), + nn.Conv2d(128, 256, kernel_size = 3, padding = 1, stride = 1), + nn.BatchNorm2d(256), + nn.ReLU(inplace = True), + nn.Conv2d(256, 256, kernel_size = 3, padding = 1, stride = 1), + nn.BatchNorm2d(256), + nn.ReLU(inplace = True), + nn.Conv2d(256, 256, kernel_size = 3, padding = 1, stride = 1), + nn.BatchNorm2d(256), + nn.ReLU(inplace = True) + ) + + self.enc_block3 = nn.Sequential( + nn.MaxPool2d(kernel_size = 2, stride = 2), + nn.Conv2d(256, 512, kernel_size = 3, padding = 1, stride = 1), + nn.BatchNorm2d(512), + nn.ReLU(inplace = True), + nn.Conv2d(512, 512, kernel_size = 3, padding = 1, stride = 1), + nn.BatchNorm2d(512), + nn.ReLU(inplace = True), + nn.Conv2d(512, 512, kernel_size = 3, padding = 1, stride = 1), + nn.BatchNorm2d(512), + nn.ReLU(inplace = True) + ) + + self.enc_block4 = nn.Sequential( + nn.MaxPool2d(kernel_size = 2, stride = 2), + nn.Conv2d(512, 512, kernel_size = 3, padding = 1, stride = 1), + nn.BatchNorm2d(512), + nn.ReLU(inplace = True), + nn.Conv2d(512, 512, kernel_size = 3, padding = 1, stride = 1), + nn.BatchNorm2d(512), + nn.ReLU(inplace = True), + nn.Conv2d(512, 512, kernel_size = 3, padding = 1, stride = 1), + nn.BatchNorm2d(512), + nn.ReLU(inplace = True) + ) + + self.bottle = nn.Sequential( + nn.MaxPool2d(kernel_size = 2, stride = 2), + nn.Conv2d(512, 512, kernel_size = 3, padding = 1, stride = 1), + nn.BatchNorm2d(512), + nn.ReLU(inplace = True) + ) + + # Decoder + # unet connection with enc_block5 (input will be 1024) + self.dec_block4 = nn.Sequential( + nn.Conv2d(1024, 512, kernel_size = 3, padding = 1, stride = 1), + nn.BatchNorm2d(512), + nn.ReLU(inplace = True), + nn.Conv2d(512, 512, kernel_size = 3, padding = 1, stride = 1), + nn.BatchNorm2d(512), + nn.ReLU(inplace = True), + nn.Conv2d(512, 512, kernel_size = 3, padding = 1, stride = 1), + nn.BatchNorm2d(512), + nn.ReLU(inplace = True) + ) + + # unet connection with enc_block5 (input will be 512) + self.dec_block3 = nn.Sequential( + nn.Conv2d(1024, 512, kernel_size = 3, padding = 1, stride = 1), + nn.BatchNorm2d(512), + nn.ReLU(inplace = True), + nn.Conv2d(512, 256, kernel_size = 3, padding = 1, stride = 1), + nn.BatchNorm2d(256), + nn.ReLU(inplace = True), + nn.Conv2d(256, 256, kernel_size = 3, padding = 1, stride = 1), + nn.BatchNorm2d(256), + nn.ReLU(inplace = True) + ) + + self.dec_block2 = nn.Sequential( + nn.Conv2d(512, 256, kernel_size = 3, padding = 1, stride = 1), + nn.BatchNorm2d(256), + nn.ReLU(inplace = True), + nn.Conv2d(256, 128, kernel_size = 3, padding = 1, stride = 1), + nn.BatchNorm2d(128), + nn.ReLU(inplace = True), + nn.Conv2d(128, 128, kernel_size = 3, padding = 1, stride = 1), + nn.BatchNorm2d(128), + nn.ReLU(inplace = True) + ) + + self.dec_block1 = nn.Sequential( + nn.Conv2d(256, 128, kernel_size = 3, padding = 1, stride = 1), + nn.BatchNorm2d(128), + nn.ReLU(inplace = True), + nn.Conv2d(128, 64, kernel_size = 3, padding = 1, stride = 1), + nn.BatchNorm2d(64), + nn.ReLU(inplace = True) + ) + + self.dec_block0 = nn.Sequential( + nn.Conv2d(128, 64, kernel_size = 3, padding = 1, stride = 1), + nn.BatchNorm2d(64), + nn.ReLU(inplace = True), + nn.Conv2d(64, 64, kernel_size = 3, padding = 1, stride = 1), + nn.BatchNorm2d(64), + nn.ReLU(inplace = True) + ) + + self.prediction = nn.Conv2d(64, 3, kernel_size = 3, padding = 1, stride = 1) + + def init_weights(self, init = 'xavier'): + if init == 'xavier': + init_func = torch.nn.init.xavier_normal_ + else: + init_func = torch.nn.init.normal_ + + if self.features is not None: + # pretrained weights initialization + self.enc_block0._modules['0'].weight = self.features._modules['0'].weight + self.enc_block0._modules['3'].weight = self.features._modules['2'].weight + self.enc_block1._modules['1'].weight = self.features._modules['5'].weight + self.enc_block1._modules['4'].weight = self.features._modules['7'].weight + self.enc_block2._modules['1'].weight = self.features._modules['10'].weight + self.enc_block2._modules['4'].weight = self.features._modules['12'].weight + self.enc_block2._modules['7'].weight = self.features._modules['14'].weight + self.enc_block3._modules['1'].weight = self.features._modules['17'].weight + self.enc_block3._modules['4'].weight = self.features._modules['19'].weight + self.enc_block3._modules['7'].weight = self.features._modules['21'].weight + self.enc_block4._modules['1'].weight = self.features._modules['24'].weight + self.enc_block4._modules['4'].weight = self.features._modules['26'].weight + self.enc_block4._modules['7'].weight = self.features._modules['28'].weight + else: + init_func(self.enc_block0._modules['0'].weight) + init_func(self.enc_block0._modules['3'].weight) + init_func(self.enc_block1._modules['1'].weight) + init_func(self.enc_block1._modules['4'].weight) + init_func(self.enc_block2._modules['1'].weight) + init_func(self.enc_block2._modules['4'].weight) + init_func(self.enc_block2._modules['7'].weight) + init_func(self.enc_block3._modules['1'].weight) + init_func(self.enc_block3._modules['4'].weight) + init_func(self.enc_block3._modules['7'].weight) + init_func(self.enc_block4._modules['1'].weight) + init_func(self.enc_block4._modules['4'].weight) + init_func(self.enc_block4._modules['7'].weight) + + #### Decoder Initialization + init_func(self.bottle._modules['1'].weight) + init_func(self.dec_block4._modules['0'].weight) + init_func(self.dec_block4._modules['3'].weight) + init_func(self.dec_block4._modules['6'].weight) + init_func(self.dec_block3._modules['0'].weight) + init_func(self.dec_block3._modules['3'].weight) + init_func(self.dec_block3._modules['6'].weight) + init_func(self.dec_block2._modules['0'].weight) + init_func(self.dec_block2._modules['3'].weight) + init_func(self.dec_block2._modules['6'].weight) + init_func(self.dec_block1._modules['0'].weight) + init_func(self.dec_block1._modules['3'].weight) + init_func(self.dec_block0._modules['0'].weight) + init_func(self.dec_block0._modules['3'].weight) + init_func(self.prediction.weight) + + def forward(self, x): + x0 = self.enc_block0(x) + x1 = self.enc_block1(x0) + x2 = self.enc_block2(x1) + x3 = self.enc_block3(x2) + x4 = self.enc_block4(x3) + b = self.bottle(x4) + + y4 = self.up(b) + y4 = torch.cat((x4, y4), dim = 1) + y4 = self.dec_block4(y4) + + y3 = self.up(y4) + y3 = torch.cat((x3, y3), dim = 1) + y3 = self.dec_block3(y3) + + y2 = self.up(y3) + y2 = torch.cat((x2, y2), dim = 1) + y2 = self.dec_block2(y2) + + y1 = self.up(y2) + y1 = torch.cat((x1, y1), dim = 1) + y1 = self.dec_block1(y1) + + y0 = self.up(y1) + y0 = torch.cat((x0, y0), dim = 1) + y0 = self.dec_block0(y0) + + pred = self.prediction(y0) + return pred \ No newline at end of file diff --git a/nn_utils.py b/nn_utils.py new file mode 100644 index 0000000..583f790 --- /dev/null +++ b/nn_utils.py @@ -0,0 +1,100 @@ +import torch +import torch.optim as optim +import torchvision.models +import random +import numpy as np +import os +import model as M + +class OptimParams(object): + def __init__(self, lr = 0.0001, momentum = 0.9, momentum2 = 0.999, eps = 1e-8, weight_decay = 0.0, damp = 0): + self.lr = lr + self.momentum = momentum + self.momentum2 = momentum2 + self.eps = eps + self.damp = damp + self.weight_decay = weight_decay + + def get_learning_rate(self): + return self.lr + + def get_momentum(self): + return self.momentum + + def get_momentum2(self): + return self.momentum2 + + def get_epsilon(self): + return self.eps + + def get_weight_decay(self): + return self.weight_decay + + def get_damp(self): + return self.damp + + +def get_optimizer(optim_type, model_params, optim_params): + if (optim_type == "adam"): + return optim.Adam( + model_params, + lr = optim_params.get_learning_rate(), + betas = (optim_params.get_momentum(), optim_params.get_momentum2()), + eps = optim_params.get_epsilon(), + weight_decay = optim_params.get_epsilon()) + else: + print("Error: Given optimizer type <{}>, is not valid".format(optim_type)) + +# def init_optimizer(optim_type, model, optim_params, optim_state = None): +# optimizer = get_optimizer(optim_type, model.parameters(), optim_params) +# if optim_state is not None: +# state = torch.load(optim_state) +# print("Loading previously saved optimizer state from {}".format(optim_state)) +# optimizer.load_state_dict(state) +# return optimizer + +def configure_device(gpus): + if (torch.cuda.is_available() and len(gpus) > 0 and gpus[0] >= 0): + device = torch.device('cuda:{}'.format(gpus[0])) + else: + device = torch.device('cpu') + print("Selected Device: {}".format(device)) + return device + + +def preseed(seed): + print("Preseeding for reproducibility with user seed: {}".format(seed)) + if seed > 0: + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + np.random.seed(seed) + torch.backends.cudnn.deterministic = True + torch.cuda.benchmark = False + random.seed(seed) + +def load_model_and_optimizer_state(model, optimizer, chkp_path): + if os.path.exists(chkp_path) and optimizer is None: + print("Loading model checkpoint from: {}...".format(chkp_path)) + model_chkp = chkp_path + model_state = torch.load(model_chkp) + model.load_state_dict(model_state) + return model + if (os.path.exists(chkp_path)): + model_chkp = chkp_path + optim_chkp = chkp_path.replace("model", "optim") + print("Loading model checkpoint from: {}...".format(model_chkp)) + print("Loading optimizer checkpoint from: {}...".format(optim_chkp)) + model_state = torch.load(model_chkp) + optim_state = torch.load(optim_chkp) + model.load_state_dict(model_state) + optimizer.load_state_dict(optim_state) + return model, optim + +def load_model(trained = True): + features = None + if trained: + orig_vgg = torchvision.models.vgg16(pretrained = trained) + features = orig_vgg.features + return M.VGG16Unet(features) + + \ No newline at end of file diff --git a/settings_files/continue.json b/settings_files/continue.json new file mode 100644 index 0000000..e69de29 diff --git a/settings_files/test.json b/settings_files/test.json new file mode 100644 index 0000000..9fa1583 --- /dev/null +++ b/settings_files/test.json @@ -0,0 +1,10 @@ +{ + "session": { + "test_batch_size": 2, + "input_shape": [256, 512], + "test_filenames_filepath": "D:\\_dev\\_Projects\\360NormalsEstimation\\dataset_filenames\\360UD_test_.txt", + "gpu": [0], + "chkp_path": "D:\\_dev\\_Projects\\360NormalsEstimation\\experiments\\vgg16_unet_quaternion_smoothness\\chkp\\vgg16_unet_quaternion_smoothness_model_e50_b390000.chkp", + "session_name": "test_testing" + } +} \ No newline at end of file diff --git a/settings_files/train.json b/settings_files/train.json new file mode 100644 index 0000000..9d52eab --- /dev/null +++ b/settings_files/train.json @@ -0,0 +1,29 @@ +{ + "session": { + "optimizer": { + "optim": "adam", + "lr": 0.0002, + "weight_decay": 0.0, + "momentum": 0.9, + "momentum2": 0.999, + "epsilon": 1e-8 + }, + "pretrained": true, + "epochs": 50, + "train_batch_size": 4, + "eval_batch_size": 2, + "input_shape": [256, 512], + "train_filenames_filepath": "absolute\\path\\to\\train_filenames_file.txt", + "validation_filenames_filepath": "absolute\\path\\to\\validation_filenames_file.txt", + "seed": 1337, + "loss": { + "alpha": 0.025 + }, + "gpu": [0], + "chkp_path": "absolute\\path\\to\\directory\\to\\save\\checkpoint_files.chkp", + "display_iterations": 1, + "chkp_iterations": 10, + "evaluation_iterations": 5, + "session_name": "desired_session_name" + } +} \ No newline at end of file diff --git a/test.py b/test.py new file mode 100644 index 0000000..2c48b55 --- /dev/null +++ b/test.py @@ -0,0 +1,92 @@ +import os +import datetime +import argparse + +import torchvision +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.utils.data.dataloader import DataLoader + +import file_utils as fu +import nn_utils as nu +import losses +import dataset360N + +''' + Global argument parser +''' +arg_parser = argparse.ArgumentParser() +arg_parser.add_argument("--conf", type = str, help = "Absolute path to the configuration file") +arg_parser.add_argument("--log", type = str, help = "Directory to save log file") +# arguments +args = arg_parser.parse_args() +cli_args = vars(args) + +''' + Simple logger +''' +class Logger: + def __init__(self, log_filepath): + self.log_filepath = log_filepath + log_file = open(self.log_filepath, 'w') + log_file.close() + + + def log(self, message): + line = "{} | {}".format(datetime.datetime.now(), message) + print(line) + log_file = open(self.log_filepath, 'a') + line += "\n" + log_file.write(line) + log_file.close() + +''' + Global Logger object +''' +logger = Logger(cli_args['log']) + +def main(cli_args): + settings = fu.read_configuration_file(cli_args['conf']) + test(settings) + +def test(settings): + logger.log("Initializing testing...") + logger.log("Configuring Device...") + gpus = settings['session']['gpu'] + device = nu.configure_device(gpus) + + logger.log("Configuring model...") + model = nu.load_model(True) + if settings['session']['chkp_path']: + model = nu.load_model_and_optimizer_state(model, None, settings['session']['chkp_path']) + else: + logger.log("Failed to load pre-trained weights. No valid checkpoint path <{}> was given".format(settings['session']['chkp_path'])) + exit() + model.to(device) + + logger.log("Configuring data loader...") + test_bsize = settings['session']['test_batch_size'] + test_set = dataset360N.Dataset360N( + settings['session']['test_filenames_filepath'], + " ", + settings['session']['input_shape'] + ) + test_loader = DataLoader(test_set, batch_size = test_bsize, shuffle = False, pin_memory = True) + + logger.log("Testing...") + total_loss = torch.tensor(0.0).to(device) + for b_idx, test_sample in enumerate(test_loader): + active_loss = torch.tensor(0.0).to(device) + + rgb = test_sample['input_rgb'].to(device) + target = test_sample['target_surface'].to(device) + mask = test_sample['mask'].to(device) + + pred = model(rgb) + pred = F.normalize(pred, p = 2, dim = 1) + logger.log("Tested: {}".format(test_sample['filename'])) + logger.log("Testing finished.") + +if __name__ == "__main__": + main(cli_args) \ No newline at end of file diff --git a/train.py b/train.py new file mode 100644 index 0000000..b3df2d1 --- /dev/null +++ b/train.py @@ -0,0 +1,184 @@ +import os +import datetime +import argparse + +import torchvision +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.utils.data.dataloader import DataLoader + +import file_utils as fu +import nn_utils as nu +import losses +import dataset360N + +''' + Global argument parser +''' +arg_parser = argparse.ArgumentParser() +arg_parser.add_argument("--conf", type = str, help = "Absolute path to the configuration file") +arg_parser.add_argument("--log", type = str, help = "Directory to save log file") +# arguments +args = arg_parser.parse_args() +cli_args = vars(args) + +''' + Simple logger +''' +class Logger: + def __init__(self, log_filepath): + self.log_filepath = log_filepath + log_file = open(self.log_filepath, 'w') + log_file.close() + + + def log(self, message): + line = "{} | {}".format(datetime.datetime.now(), message) + print(line) + log_file = open(self.log_filepath, 'a') + line += "\n" + log_file.write(line) + log_file.close() + +''' + Global Logger object +''' +logger = Logger(cli_args['log']) + +''' + main function +''' +def main(cli_args): + settings = fu.read_configuration_file(cli_args['conf']) + train(settings) + + +''' + Trains model +''' +def train(settings): + logger.log("Initializing training...") + + logger.log("Configuring Device...") + gpus = settings['session']['gpu'] + device = nu.configure_device(gpus) + + logger.log("Configuring Model...") + model = None + if settings['session']['pretrained']: + logger.log("Loading pre-trained weights...") + model = nu.load_"model(True) + else: + model = nu.load_model(False) + logger.log("Initializing model weights (using Xavier initialization)...") + model.init_weights() + model.to(device) + + logger.log("Configuring Optimizer...") + optim_settings = settings["session"]["optimizer"] + opt = optim_settings['optim'] + lr = optim_settings['lr'] + wd = optim_settings['weight_decay'] + mom = optim_settings['momentum'] + mom2 = optim_settings['momentum2'] + eps = optim_settings['epsilon'] + optim_params = nu.OptimParams(lr, mom, mom2, eps, wd) + optimizer = nu.get_optimizer(opt, model.parameters(), optim_params) + # optimizer.to(device) + + logger.log("Preseeding...") + nu.preseed(settings['session']['seed']) + + # make train loader + logger.log("Configuring data loader...") + train_bsize = settings['session']['train_batch_size'] + eval_bsize = settings['session']['eval_batch_size'] + train_set = dataset360N.Dataset360N( + settings["session"]["train_filenames_filepath"], + " ", + settings["session"]["input_shape"]) + eval_set = dataset360N.Dataset360N( + settings["session"]["validation_filenames_filepath"], + " ", + settings["session"]["input_shape"]) + train_loader = DataLoader(train_set, batch_size = train_bsize, shuffle = True, pin_memory = True) + eval_loader = DataLoader(eval_set, batch_size = eval_bsize, shuffle = True, pin_memory = True) + + epochs = settings['session']["epochs"] + epoch_range = range(epochs) + disp_iters = settings['session']["display_iterations"] + chkp_iters = settings['session']["chkp_iterations"] + eval_iters = settings['session']["evaluation_iterations"] + chkp_path = settings['session']["chkp_path"] + sess_name = settings['session']["session_name"] + alpha = settings['session']["loss"]["alpha"] + + logger.log("Training...") + g_iters = 0 + for e in epoch_range: + for b_idx, train_sample in enumerate(train_loader): + active_loss = torch.tensor(0.0).to(device) + quaternion_loss = 0.0 + smoothness_loss = 0.0 + + rgb = train_sample["input_rgb"].to(device) + target = train_sample["target_surface"].to(device) + mask = train_sample["mask"].to(device) + + pred = model(rgb) + pred = F.normalize(pred, p = 2, dim = 1) + + quat_loss, quat_loss_map = losses.quaternion_loss(pred, target, True, mask) + quaternion_loss += quat_loss * (1 - alpha) + smooth_loss, smooth_loss_map = losses.smoothness_loss(pred, True, mask) + smoothness_loss += smooth_loss * alpha + + active_loss = quat_loss * (1 - alpha) + smooth_loss * (alpha) + + optimizer.zero_grad() + active_loss.backward() + optimizer.step() + + g_iters += train_bsize + if g_iters % chkp_iters == 0: + logger.log("Saving Checkpoint in: {}".format(chkp_path)) + fu.save_state(chkp_path, sess_name, model, optimizer, e + 1, g_iters) + if g_iters % disp_iters == 0: + logger.log("Epoch: {} | Training iter: {} | Training Loss:".format(e + 1, g_iters)) + logger.log("\t\t\t\t\tTotal Loss : {}".format(active_loss)) + logger.log("\t\t\t\t\tQuaternion Loss: {}".format(quat_loss)) + logger.log("\t\t\t\t\tSmoothness Loss: {}".format(smooth_loss)) + if g_iters % eval_iters == 0: + logger.log("Evaluating...") + model.eval() + eval_loss = 0.0 + counter = 0.0 + with torch.no_grad(): + active_loss = torch.tensor(0.0).to(device) + for eval_b_idx, eval_sample in enumerate(eval_loader): + quaternion_loss = 0.0 + smoothness_loss = 0.0 + + rgb = train_sample["input_rgb"].to(device) + target = train_sample["target_surface"].to(device) + mask = train_sample["mask"].to(device) + + pred = model(rgb) + pred = F.normalize(pred, p = 2, dim = 1) + + quat_loss, quat_loss_map = losses.quaternion_loss(pred, target, True, mask) + quaternion_loss += quat_loss * (1 - alpha) + smooth_loss, smooth_loss_map = losses.smoothness_loss(pred, True, mask) + smoothness_loss += smooth_loss * alpha + + active_loss += quat_loss * (1 - alpha) + smooth_loss * (alpha) + counter += eval_bsize + total_loss = active_loss / counter + logger.log("Evaluation finished. Total Loss: {}".format(total_loss)) + logger.log("Epoch {} finished.".format(e + 1)) + logger.log("Training session finished.") + + +if __name__ == "__main__": + main(cli_args) \ No newline at end of file