-
Notifications
You must be signed in to change notification settings - Fork 8
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
12 changed files
with
904 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,111 @@ | ||
import os | ||
import torch | ||
import numpy as np | ||
import cv2 | ||
from torch.utils.data import Dataset | ||
from torchvision import transforms, utils | ||
import torch.nn.functional as F | ||
import PIL.Image as Image | ||
import random | ||
|
||
|
||
class Dataset360N(Dataset):
    """Dataset pairing RGB images with surface (normal) maps.

    Sample paths are read from a text file holding one sample per line, with
    fields separated by ``delimiter``. Field 0 is the RGB image path and
    field 6 is the surface map path.

    Args:
        filenames_filepath: path to the text file listing the samples.
        delimiter: field separator used inside that file.
        input_shape: ``(height, width)``; only used to size the all-zero
            placeholders returned for files that cannot be loaded.
    """

    # Column indices of the two paths inside each line of the filenames file.
    # TODO: check for filenames files format
    _RGB_COLUMN = 0
    _SURFACE_COLUMN = 6

    # Dominant coordinate-system normals that near-aligned normals snap to.
    _AXIS_NORMALS = (
        (1, 0, 0),
        (0, 1, 0),
        (0, 0, 1),
        (-1, 0, 0),
        (0, -1, 0),
        (0, 0, -1),
    )

    def __init__(self, filenames_filepath, delimiter, input_shape):
        self.length = 0
        self.height = input_shape[0]
        self.width = input_shape[1]
        self.filenames_filepath = filenames_filepath
        self.delim = delimiter
        self.data_paths = {}
        self.init_data_dict()
        self.gather_filepaths()

    def init_data_dict(self):
        """Reset the per-modality path lists."""
        self.data_paths = {
            "rgb": [],
            "surface": [],
        }

    def gather_filepaths(self):
        """Read the filenames file and collect the RGB / surface paths.

        Blank lines are skipped, and line terminators plus surrounding
        whitespace are removed from each path so that a path stored in the
        last column does not keep its trailing newline.

        Raises:
            IndexError: if a non-blank line has fewer than 7 fields.
        """
        with open(self.filenames_filepath, 'r') as fd:
            for line in fd:
                line = line.rstrip("\r\n")
                if not line.strip():
                    continue
                splits = line.split(self.delim)
                # Resolve both fields before appending so the two lists can
                # never get out of step on a malformed line.
                rgb_path = splits[self._RGB_COLUMN].strip()
                surface_path = splits[self._SURFACE_COLUMN].strip()
                self.data_paths["rgb"].append(rgb_path)
                self.data_paths["surface"].append(surface_path)
        self.length = len(self.data_paths["rgb"])

    def _blank_image(self):
        """Return an all-zero (height, width, 3) float32 placeholder."""
        return np.zeros((self.height, self.width, 3), dtype = np.float32)

    @staticmethod
    def _make_mask(surface_np):
        """Zero out NaNs in place and build the validity mask.

        A pixel is invalid (mask 0) only when all of its channels are zero;
        testing the channel *sum* would also reject valid normals whose
        components cancel out, e.g. (1, -1, 0).

        Returns:
            Array shaped like ``surface_np`` holding 1 for valid pixels and
            0 for invalid ones.
        """
        surface_np[np.isnan(surface_np)] = 0.0
        mask_np = np.ones_like(surface_np)
        mask_np[np.all(surface_np == 0.0, axis = 2)] = 0.0
        return mask_np

    def load_rgb(self, filepath):
        """Load an RGB image; return a zero placeholder if it is unavailable."""
        if not os.path.exists(filepath):
            print("\tGiven filepath <{}> does not exist".format(filepath))
            return self._blank_image()
        rgb_np = cv2.imread(filepath, cv2.IMREAD_ANYCOLOR)
        if rgb_np is None:
            # cv2.imread signals an unreadable file by returning None.
            print("\tGiven filepath <{}> could not be read".format(filepath))
            return self._blank_image()
        return rgb_np

    def load_float(self, filepath):
        """Load a surface map together with its validity mask.

        Returns:
            ``(surface_np, mask_np)``. When the file is missing or unreadable
            both are all-zero placeholders, i.e. every pixel is masked out.
        """
        if not os.path.exists(filepath):
            print("\tGiven filepath <{}> does not exist".format(filepath))
            return self._blank_image(), self._blank_image()
        surface_np = cv2.imread(filepath, cv2.IMREAD_UNCHANGED)
        if surface_np is None:
            print("\tGiven filepath <{}> could not be read".format(filepath))
            return self._blank_image(), self._blank_image()
        # Creates mask for invalid values
        mask_np = self._make_mask(surface_np)
        return surface_np, mask_np

    def clean_normal(self, normal):
        """Snap normals close to a coordinate axis exactly onto that axis.

        ``normal`` is an (H, W, 3) array; it is modified in place and also
        returned.
        """
        for axis_normal in self._AXIS_NORMALS:
            vec = np.asarray(axis_normal, dtype = np.float32)
            # Broadcasting compares every pixel against the axis vector; a
            # pixel matches only when all three components are close.
            inds = np.isclose(normal, vec, rtol = 0.0001, atol = 0.1).all(axis = 2)
            normal[inds] = vec
        return normal

    def make_tensor(self, np_array):
        """Convert an (H, W, C) array into a (C, H, W) float32 tensor."""
        np_array = np_array.transpose(2, 0, 1)
        tensor = torch.from_numpy(np_array)
        return torch.as_tensor(tensor, dtype = torch.float32)

    def load_item(self, idx):
        """Load the sample at ``idx``.

        Returns:
            Dict with ``input_rgb``, ``target_surface``, ``mask`` (all
            (C, H, W) float32 tensors) and ``filename``. An empty dict is
            returned, after printing a message, when ``idx`` is past the end.
        """
        item = { }
        if idx >= self.length:
            print("Index out of range.")
            return item
        surface_path = self.data_paths["surface"][idx]
        rgb_np = self.load_rgb(self.data_paths["rgb"][idx])
        surface_np, mask_np = self.load_float(surface_path)
        surface_np = self.clean_normal(surface_np)
        surface = self.make_tensor(surface_np)
        # make_tensor yields (C, H, W), so the per-pixel normal vector lies
        # along dim 0; normalising along dim 1 would rescale image columns.
        surface = F.normalize(surface, p = 2, dim = 0)
        item['input_rgb'] = self.make_tensor(rgb_np)
        item['target_surface'] = surface
        item['mask'] = self.make_tensor(mask_np)
        item['filename'] = os.path.basename(surface_path)
        return item

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        return self.load_item(idx)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
import os | ||
import json | ||
import torch | ||
|
||
''' | ||
Reads configuration file | ||
\param | ||
filepath the absolute path to the configuration file | ||
\return | ||
settings_map dictionary with the configuration settings | ||
''' | ||
def read_configuration_file(filepath):
    """Read a JSON configuration file and check it has session settings.

    Args:
        filepath: path to the JSON configuration file.

    Returns:
        Dictionary with the configuration settings.

    Raises:
        KeyError: if the file does not exist or has no 'session' entry.
        AssertionError: if the 'session' entry is present but empty.
    """
    print("Reading configuration file...")
    settings = {}
    if os.path.exists(filepath):
        with open(filepath, 'r') as fd:
            settings = json.load(fd)
    # Raised explicitly rather than via `assert`, so the check is not stripped
    # under `python -O` and the message actually reaches the exception.
    if not settings['session']:
        raise AssertionError("Failed to read configuration file. No session settings.")
    # NOTE: settings['session']['optimizer'] is not validated here; that check
    # was already disabled in the original code.
    return settings
|
||
def save_state(directory, session_name, model, optimizer, epoch, global_iters):
    """Save model and optimizer state dicts as two checkpoint files.

    The files are named ``<session_name>_model_e_<epoch>_b_<global_iters>.chkp``
    and ``<session_name>_optim_e_<epoch>_b_<global_iters>.chkp``.

    Args:
        directory: output directory; if it is the path of an existing file,
            that file's parent directory is used instead.
        session_name: prefix for the checkpoint filenames.
        model: object exposing ``state_dict()``.
        optimizer: object exposing ``state_dict()``.
        epoch: epoch number embedded in the filenames.
        global_iters: iteration count embedded in the filenames.
    """
    if os.path.isfile(directory):
        directory = os.path.abspath(os.path.dirname(directory))
    if directory:
        # torch.save does not create missing parent directories.
        os.makedirs(directory, exist_ok = True)
    suffix = "_e_{}_b_{}.chkp".format(epoch, global_iters)
    model_filepath = os.path.join(directory, session_name + "_model" + suffix)
    optim_filepath = os.path.join(directory, session_name + "_optim" + suffix)
    torch.save(model.state_dict(), model_filepath)
    torch.save(optimizer.state_dict(), optim_filepath)
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,118 @@ | ||
import torch | ||
import torch.nn.functional as F | ||
import numpy as np | ||
|
||
# image gradient computations | ||
''' | ||
Image gradient x-direction | ||
\param | ||
input_tensor | ||
\return | ||
input_tensor's x-direction gradients | ||
''' | ||
def grad_x(input_tensor):
    """Return the x-direction differences ``t[..., i] - t[..., i + 1]``.

    The last column is replicated before differencing, so the output has the
    same shape as the 4-D input and its final column is zero.
    """
    width = input_tensor.shape[3]
    padded = F.pad(input_tensor, (0, 1, 0, 0), mode = "replicate")
    current = padded.narrow(3, 0, width)
    shifted = padded.narrow(3, 1, width)
    return current - shifted
|
||
''' | ||
Image gradient y-direction | ||
\param | ||
input_tensor | ||
\return | ||
input_tensor's y-direction gradients | ||
''' | ||
def grad_y(input_tensor):
    """Return the y-direction differences ``t[:, :, i] - t[:, :, i + 1]``.

    The last row is replicated before differencing, so the output has the
    same shape as the 4-D input and its final row is zero.
    """
    height = input_tensor.shape[2]
    padded = F.pad(input_tensor, (0, 0, 0, 1), mode = "replicate")
    current = padded.narrow(2, 0, height)
    shifted = padded.narrow(2, 1, height)
    return current - shifted
|
||
''' | ||
L2 Loss | ||
\param | ||
input input tensor (model's prediction) | ||
target target tensor (ground truth) | ||
use_mask set True to compute masked loss | ||
mask Binary mask tensor | ||
\return | ||
L2 loss mean between target and input | ||
L2 loss map between target and input | ||
''' | ||
def l2_loss(input, target, use_mask = True, mask = None):
    """Squared-error (L2) loss, optionally restricted to a mask.

    Args:
        input: input tensor (model's prediction).
        target: target tensor (ground truth).
        use_mask: set True to compute the masked loss.
        mask: binary mask tensor, broadcastable against the loss map.

    Returns:
        ``(mean, loss_map)``. With a mask, the mean is the masked sum divided
        by the number of set mask elements, and the map is the masked loss.
    """
    loss = torch.pow(target - input, 2)
    if use_mask and mask is not None:
        masked_loss = loss * mask
        count = torch.sum(mask)
        # An all-zero mask would give 0 / 0 = NaN; divide by 1 instead so the
        # loss of a fully masked sample is 0.
        count = torch.where(count > 0, count, torch.ones_like(count))
        return torch.sum(masked_loss) / count, masked_loss
    return torch.mean(loss), loss
|
||
''' | ||
Cosine Similarity loss (vector dot product) | ||
\param | ||
input input tensor (model's prediction) | ||
target target tensor (ground truth) | ||
use_mask set True to compute masked loss | ||
mask Binary mask tensor | ||
\return | ||
Cosine similarity loss mean between target and input | ||
Cosine similarity loss map between target and input | ||
''' | ||
def cosine_loss(input, target, use_mask = True, mask = None):
    """Cosine-similarity loss: one minus the channel-wise dot product.

    Args:
        input: input tensor (model's prediction).
        target: target tensor (ground truth).
        use_mask: set True to compute the masked loss.
        mask: binary mask tensor, broadcastable against the loss map.

    Returns:
        ``(mean, loss_map)``. The unmasked map keeps a singleton channel
        dimension; with a mask, the map is ``loss * mask`` and the mean is
        its sum divided by the number of set mask elements.
    """
    dot = torch.sum(input * target, dim = 1, keepdim = True)
    loss = 1 - dot
    if use_mask and mask is not None:
        masked_loss = loss * mask
        count = torch.sum(mask)
        # An all-zero mask would give 0 / 0 = NaN; divide by 1 instead so the
        # loss of a fully masked sample is 0.
        count = torch.where(count > 0, count, torch.ones_like(count))
        return torch.sum(masked_loss) / count, masked_loss
    return torch.mean(loss), loss
''' | ||
Quaternion loss | ||
\param | ||
input input tensor (model's prediction) | ||
target target tensor (ground truth) | ||
use_mask set True to compute masked loss | ||
mask Binary mask tensor | ||
\return | ||
Quaternion loss mean between target and input | ||
Quaternion loss map between target and input | ||
''' | ||
def quaternion_loss(input, target, use_mask = True, mask = None):
    """Angular loss from the quaternion product of target and negated input.

    The three channels of each (N, 3, H, W) tensor are treated as the vector
    part of a quaternion. The loss is ``atan2(|vector part|, real part) / pi``
    of the product of ``target`` and ``-input``.

    Args:
        input: input tensor (model's prediction), shape (N, 3, H, W).
        target: target tensor (ground truth), shape (N, 3, H, W).
        use_mask: set True to compute the masked loss.
        mask: binary mask tensor; only its first channel is used.

    Returns:
        ``(mean, loss_map)`` with a (N, 1, H, W) loss map. With a mask, the
        mean is taken over the pixels set in the mask's first channel.
    """
    q_pred = -input
    t0, t1, t2 = (target[:, c, :, :] for c in range(3))
    q0, q1, q2 = (q_pred[:, c, :, :] for c in range(3))

    # Vector part: cross product of target and negated prediction.
    cross_x = t1 * q2 - t2 * q1
    cross_y = t2 * q0 - t0 * q2
    cross_z = t0 * q1 - t1 * q0
    # Real part: minus their dot product.
    real_diff = (-t0 * q0 - t1 * q1 - t2 * q2).unsqueeze(1)

    sq_norm = (cross_x * cross_x + cross_y * cross_y + cross_z * cross_z).unsqueeze(1)
    # Clamp before the square root so the gradient stays finite at zero.
    vec_diff = torch.sqrt(torch.clamp(sq_norm, min = 1e-8))

    loss = torch.atan2(vec_diff, real_diff) / np.pi

    if use_mask and mask is not None:
        mask = mask[:, 0, :, :].unsqueeze(1)
        masked_loss = loss * mask
        # Count only the channel that multiplies the loss; counting every
        # channel of a 3-channel mask would shrink the mean by a factor of 3.
        count = torch.sum(mask)
        # An all-zero mask would give 0 / 0 = NaN; divide by 1 instead.
        count = torch.where(count > 0, count, torch.ones_like(count))
        return torch.sum(masked_loss) / count, masked_loss
    # Return (mean, map) here too, matching the masked path.
    return torch.mean(loss), loss
|
||
''' | ||
Smoothness loss | ||
\param | ||
input input tensor (model's prediction) | ||
''' | ||
def smoothness_loss(input, use_mask = True, mask = None):
    """Smoothness loss: sum of absolute x- and y-direction gradients.

    Args:
        input: input tensor (model's prediction), 4-D.
        use_mask: set True to compute the masked loss.
        mask: binary mask tensor, broadcastable against the loss map.

    Returns:
        ``(mean, loss_map)``. With a mask, the mean is the masked sum divided
        by the number of set mask elements, and the map is the masked loss.
    """
    loss = torch.abs(grad_x(input)) + torch.abs(grad_y(input))
    # Honour use_mask like the other losses in this file do.
    if use_mask and mask is not None:
        masked_loss = mask * loss
        count = torch.sum(mask)
        # An all-zero mask would give 0 / 0 = NaN; divide by 1 instead so the
        # loss of a fully masked sample is 0.
        count = torch.where(count > 0, count, torch.ones_like(count))
        return torch.sum(masked_loss) / count, masked_loss
    return torch.mean(loss), loss
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
import numpy as np | ||
import torch |
Oops, something went wrong.