Commit ffe30e1
train done
soobinseo committed Sep 6, 2023
1 parent 65d7067 commit ffe30e1
Showing 8 changed files with 167 additions and 86 deletions.
6 changes: 3 additions & 3 deletions code/confs/dataset/video.yaml
@@ -1,7 +1,7 @@
 metainfo:
   gender: 'female'
-  data_dir : data_equi/
-  subject: "data_equi"
+  data_dir : data_bbox_2d_init_training_no_pose_1e4/
+  subject: "data_bbox_2d_init_training_no_pose_1e4"
   start_frame: 0
   end_frame: 41
 
@@ -12,7 +12,7 @@ train:
   shuffle: True
   worker: 8
 
-  num_sample : 512
+  num_sample : 1024
 
 valid:
   type: "VideoVal"
5 changes: 3 additions & 2 deletions code/confs/model/model_w_bg.yaml
@@ -1,4 +1,5 @@
 learning_rate : 5.0e-4
+body_param_learning_rate : 1.0e-4
 sched_milestones : [200,500]
 sched_factor : 0.5
 smpl_init: True
@@ -21,7 +22,6 @@ implicit_network:
   embedder_mode: 'fourier'
   multires: 6
   cond: 'smpl'
-  scene_bounding_sphere: 3.0
 rendering_network:
   feature_vector_size: 256
   mode: "pose"
@@ -61,7 +61,7 @@ density:
   params_init: {beta: 0.1}
   beta_min: 0.0001
 ray_sampler:
-  near: 0.0
+  near: 0.5
   N_samples: 64
   N_samples_eval: 128
   N_samples_extra: 32
@@ -75,3 +75,4 @@ loss:
   bce_weight: 5.0e-3
   opacity_sparse_weight: 3.0e-3
   in_shape_weight: 1.0e-2
+  mask_weight: 1.0e-3
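
Note: sched_milestones and sched_factor are presumably consumed by a MultiStepLR-style scheduler, and the new body_param_learning_rate suggests the SMPL body parameters get their own optimizer group at a lower rate. A minimal sketch of that wiring, assuming illustrative module names (the real names in this repo may differ):

    import torch
    from torch import nn

    # Toy stand-ins for the real networks; the names are assumptions.
    implicit_network = nn.Linear(3, 256)
    body_model_params = nn.Embedding(42, 86)  # per-frame SMPL parameters

    optimizer = torch.optim.Adam(
        [
            {"params": implicit_network.parameters(), "lr": 5.0e-4},   # learning_rate
            {"params": body_model_params.parameters(), "lr": 1.0e-4},  # body_param_learning_rate
        ]
    )
    # sched_milestones / sched_factor map onto MultiStepLR milestones / gamma.
    scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=[200, 500], gamma=0.5
    )

    for epoch in range(600):
        # ... one training epoch ...
        scheduler.step()  # both group lrs are halved at epochs 200 and 500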
48 changes: 21 additions & 27 deletions code/lib/datasets/dataset.py
@@ -75,32 +75,13 @@ def __init__(self, metainfo, split):
         # cameras
         camera_poses = np.load(os.path.join(root, "camera_pos.npy"))
         camera_rotates = np.load(os.path.join(root, "camera_rotate.npy"))
-        # camera_dict = np.load(os.path.join(root, "cameras_normalize.npz"))
-        # scale_mats = [
-        #     camera_dict["scale_mat_%d" % idx].astype(np.float32)
-        #     for idx in self.training_indices
-        # ]
-        # world_mats = [
-        #     camera_dict["world_mat_%d" % idx].astype(np.float32)
-        #     for idx in self.training_indices
-        # ]
 
-        # self.scale = 1 / scale_mats[0][0, 0]
         self.scale = 1
 
         self.camera_poses = torch.tensor(camera_poses, dtype=torch.float32)
         self.camera_rotates = torch.tensor(camera_rotates, dtype=torch.float32)
 
-        # self.intrinsics_all = []
-        # self.pose_all = []
-        # for scale_mat, world_mat in zip(scale_mats, world_mats):
-        #     P = world_mat @ scale_mat
-        #     P = P[:3, :4]
-        #     intrinsics, pose = utils.load_K_Rt_from_P(None, P)
-        #     self.intrinsics_all.append(torch.from_numpy(intrinsics).float())
-        #     self.pose_all.append(torch.from_numpy(pose).float())
-        # assert len(self.intrinsics_all) == len(self.pose_all)
 
         # other properties
         self.num_sample = split.num_sample
         self.sampling_strategy = "weighted"
@@ -117,12 +98,15 @@ def __getitem__(self, idx):
 
         mask = cv2.imread(self.mask_paths[idx])
         # preprocess: BGR -> Gray -> Mask
-        mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY) > 0
-        img_size = self.img_size
+        mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)
+        dilate_kernel = np.ones((20, 20), np.uint8)
+        mask_for_sampling = cv2.dilate(mask, dilate_kernel)
+
+        mask = mask > 0
+        mask_for_sampling = mask_for_sampling > 0
         # mask = mask / 255.0
 
-        # uv = np.mgrid[: img_size[0], : img_size[1]].astype(np.int32)
-        # uv = np.flip(uv, axis=0).copy().transpose(1, 2, 0).astype(np.float32)
+        img_size = self.img_size
 
         x = np.linspace(-0.5, 0.5, img_size[0], endpoint=False)
         y = np.linspace(-0.5, 0.5, img_size[1], endpoint=False)
@@ -140,7 +124,8 @@ def __getitem__(self, idx):
             data = {
                 "rgb": img,
                 "uv": uv,
-                "object_mask": mask,
+                "object_mask": mask_for_sampling,
+                "mask": mask,
             }
 
             samples, index_outside = utils.weighted_sampling(
@@ -153,10 +138,13 @@ def __getitem__(self, idx):
                 "camera_poses": self.camera_poses[idx],
                 "camera_rotates": self.camera_rotates[idx],
                 "smpl_params": smpl_params,
-                "index_outside": index_outside,
+                # "index_outside": index_outside,
                 "idx": idx,
             }
-            images = {"rgb": samples["rgb"].astype(np.float32)}
+            images = {
+                "rgb": samples["rgb"].astype(np.float32),
+                "mask": samples["mask"].astype(np.float32),
+            }
             return inputs, images
         else:
            inputs = {
@@ -171,6 +159,7 @@ def __getitem__(self, idx):
             images = {
                 "rgb": img.reshape(-1, 3).astype(np.float32),
                 "img_size": self.img_size,
+                "mask": mask.reshape(-1).astype(np.float32),
             }
             return inputs, images
 
@@ -206,6 +195,7 @@ def __getitem__(self, idx):
             "img_size": images["img_size"],
             "pixel_per_batch": self.pixel_per_batch,
             "total_pixels": self.total_pixels,
+            "mask": images["mask"],
         }
         return inputs, images
 
@@ -235,5 +225,9 @@ def __getitem__(self, idx):
             "smpl_params": inputs["smpl_params"],
             "idx": inputs["idx"],
         }
-        images = {"rgb": images["rgb"], "img_size": images["img_size"]}
+        images = {
+            "rgb": images["rgb"],
+            "img_size": images["img_size"],
+            "mask": images["mask"],
+        }
        return inputs, images, self.pixel_per_batch, self.total_pixels, idx
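
Note: the dataset now keeps two masks per frame: the exact silhouette ("mask", used as supervision) and a dilated copy ("object_mask", used to weight ray sampling toward the subject plus a thin band around it). A self-contained sketch of that preprocessing with a synthetic mask (the 20x20 kernel follows the diff; the image size is an assumption):

    import cv2
    import numpy as np

    # Synthetic silhouette standing in for cv2.imread(mask_path).
    mask = np.zeros((540, 960), np.uint8)
    cv2.circle(mask, (480, 270), 100, 255, -1)

    # Grow the silhouette by roughly the kernel size so rays are also
    # sampled just outside the subject boundary.
    dilate_kernel = np.ones((20, 20), np.uint8)
    mask_for_sampling = cv2.dilate(mask, dilate_kernel)

    mask = mask > 0                             # exact silhouette -> "mask"
    mask_for_sampling = mask_for_sampling > 0   # enlarged region -> "object_mask"
    print(mask.sum(), mask_for_sampling.sum())  # the dilated mask covers more pixels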
114 changes: 84 additions & 30 deletions code/lib/model/loss.py
@@ -2,63 +2,117 @@
 from torch import nn
 from torch.nn import functional as F
 
+
 class Loss(nn.Module):
     def __init__(self, opt):
         super().__init__()
         self.eikonal_weight = opt.eikonal_weight
         self.bce_weight = opt.bce_weight
         self.opacity_sparse_weight = opt.opacity_sparse_weight
         self.in_shape_weight = opt.in_shape_weight
+        self.mask_weight = opt.mask_weight
         self.eps = 1e-6
         self.milestone = 200
-        self.l1_loss = nn.L1Loss(reduction='mean')
-        self.l2_loss = nn.MSELoss(reduction='mean')
+        self.l1_loss = nn.L1Loss(reduction="mean")
+        self.l2_loss = nn.MSELoss(reduction="mean")
 
     # L1 reconstruction loss for RGB values
     def get_rgb_loss(self, rgb_values, rgb_gt):
         rgb_loss = self.l1_loss(rgb_values, rgb_gt)
         return rgb_loss
 
     # Eikonal loss introduced in IGR
     def get_eikonal_loss(self, grad_theta):
-        eikonal_loss = ((grad_theta.norm(2, dim=-1) - 1)**2).mean()
+        eikonal_loss = ((grad_theta.norm(2, dim=-1) - 1) ** 2).mean()
         return eikonal_loss
 
     # BCE loss for clear boundary
-    def get_bce_loss(self, acc_map):
-        binary_loss = -1 * (acc_map * (acc_map + self.eps).log() + (1-acc_map) * (1 - acc_map + self.eps).log()).mean() * 2
+    def get_bce_loss(self, acc_map, mask=None):
+        if mask is not None:
+            gt_acc_map = mask
+        else:
+            gt_acc_map = acc_map
+        binary_loss = (
+            -1
+            * (
+                gt_acc_map * (acc_map + self.eps).log()
+                + (1 - gt_acc_map) * (1 - acc_map + self.eps).log()
+            ).mean()
+            * 2
+        )
         return binary_loss
 
-    # Global opacity sparseness regularization 
+    # Global opacity sparseness regularization
     def get_opacity_sparse_loss(self, acc_map, index_off_surface):
-        opacity_sparse_loss = self.l1_loss(acc_map[index_off_surface], torch.zeros_like(acc_map[index_off_surface]))
+        opacity_sparse_loss = self.l1_loss(
+            acc_map[index_off_surface], torch.zeros_like(acc_map[index_off_surface])
+        )
         return opacity_sparse_loss
 
     # Optional: This loss helps to stabilize the training in the very beginning
     def get_in_shape_loss(self, acc_map, index_in_surface):
-        in_shape_loss = self.l1_loss(acc_map[index_in_surface], torch.ones_like(acc_map[index_in_surface]))
+        in_shape_loss = self.l1_loss(
+            acc_map[index_in_surface], torch.ones_like(acc_map[index_in_surface])
+        )
        return in_shape_loss
 
+    def get_mask_loss(self, acc_map, target_mask):
+        mask_loss = self.l1_loss(acc_map, target_mask)
+        return mask_loss
+
     def forward(self, model_outputs, ground_truth):
-        nan_filter = ~torch.any(model_outputs['rgb_values'].isnan(), dim=1)
-        rgb_gt = ground_truth['rgb'][0].cuda()
-        rgb_loss = self.get_rgb_loss(model_outputs['rgb_values'][nan_filter], rgb_gt[nan_filter])
-        eikonal_loss = self.get_eikonal_loss(model_outputs['grad_theta'])
-        bce_loss = self.get_bce_loss(model_outputs['acc_map'])
-        opacity_sparse_loss = self.get_opacity_sparse_loss(model_outputs['acc_map'], model_outputs['index_off_surface'])
-        in_shape_loss = self.get_in_shape_loss(model_outputs['acc_map'], model_outputs['index_in_surface'])
-        curr_epoch_for_loss = min(self.milestone, model_outputs['epoch']) # will not increase after the milestone
+        nan_filter = ~torch.any(model_outputs["rgb_values"].isnan(), dim=1)
+        rgb_gt = ground_truth["rgb"][0].cuda()
+
+        # inside_idx = (ground_truth["mask"] > 0).squeeze(0)
+        # rgb_inside_loss = self.get_rgb_loss(
+        #     model_outputs["rgb_values"][nan_filter][inside_idx],
+        #     rgb_gt[nan_filter][inside_idx],
+        # )
+        # rgb_outside_loss = self.get_rgb_loss(
+        #     model_outputs["rgb_values"][nan_filter][~inside_idx],
+        #     rgb_gt[nan_filter][~inside_idx],
+        # )
+        # rgb_loss = 0.9 * rgb_inside_loss + 0.1 * rgb_outside_loss
+
+        rgb_loss = self.get_rgb_loss(
+            model_outputs["rgb_values"][nan_filter],
+            rgb_gt[nan_filter],
+        )
+        eikonal_loss = self.get_eikonal_loss(model_outputs["grad_theta"])
+        if model_outputs["epoch"] < 20 or model_outputs["epoch"] % 20 == 0:
+            bce_loss = self.get_bce_loss(model_outputs["acc_map"])
+        else:
+            bce_loss = self.get_bce_loss(model_outputs["acc_map"], mask=None)
+        opacity_sparse_loss = self.get_opacity_sparse_loss(
+            model_outputs["acc_map"], model_outputs["index_off_surface"]
+        )
+        in_shape_loss = self.get_in_shape_loss(
+            model_outputs["acc_map"], model_outputs["index_in_surface"]
+        )
+        mask_loss = self.get_mask_loss(model_outputs["acc_map"], ground_truth["mask"])
+        curr_epoch_for_loss = min(
+            self.milestone, model_outputs["epoch"]
+        )  # will not increase after the milestone
 
-        loss = rgb_loss + \
-               self.eikonal_weight * eikonal_loss + \
-               self.bce_weight * bce_loss + \
-               self.opacity_sparse_weight * (1 + curr_epoch_for_loss ** 2 / 40) * opacity_sparse_loss + \
-               self.in_shape_weight * (1 - curr_epoch_for_loss / self.milestone) * in_shape_loss
+        loss = (
+            rgb_loss
+            + self.eikonal_weight * eikonal_loss
+            + self.bce_weight * bce_loss
+            + self.opacity_sparse_weight
+            * (1 + curr_epoch_for_loss**2 / 40)
+            * opacity_sparse_loss
+            + self.in_shape_weight
+            * (1 - curr_epoch_for_loss / self.milestone)
+            * in_shape_loss
+            # + self.mask_weight * mask_loss
+        )
         return {
-            'loss': loss,
-            'rgb_loss': rgb_loss,
-            'eikonal_loss': eikonal_loss,
-            'bce_loss': bce_loss,
-            'opacity_sparse_loss': opacity_sparse_loss,
-            'in_shape_loss': in_shape_loss,
-        }
+            "loss": loss,
+            "rgb_loss": rgb_loss,
+            "eikonal_loss": eikonal_loss,
+            "bce_loss": bce_loss,
+            "opacity_sparse_loss": opacity_sparse_loss,
+            "in_shape_loss": in_shape_loss,
+            "mask_loss": mask_loss,
+        }
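
Note: two of the loss weights are scheduled by epoch: the opacity sparseness weight grows quadratically as (1 + e^2 / 40) and the in-shape weight decays linearly to zero, both frozen once the epoch reaches the milestone (200). A small numeric illustration of the effective weights, using the values from model_w_bg.yaml:

    milestone = 200
    opacity_sparse_weight = 3.0e-3
    in_shape_weight = 1.0e-2

    for epoch in (0, 20, 100, 200, 500):
        e = min(milestone, epoch)  # frozen after the milestone
        w_sparse = opacity_sparse_weight * (1 + e**2 / 40)
        w_in_shape = in_shape_weight * (1 - e / milestone)
        print(f"epoch {epoch:3d}: sparse {w_sparse:.4f}, in-shape {w_in_shape:.4f}")
    # epoch   0: sparse 0.0030, in-shape 0.0100
    # epoch 200: sparse 3.0030, in-shape 0.0000 (unchanged at epoch 500)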
10 changes: 6 additions & 4 deletions code/lib/model/v2a.py
@@ -115,7 +115,7 @@ def forward(self, input):
         camera_pos = input["camera_poses"]
         camera_rotate = input["camera_rotates"]
 
-        scale = input["smpl_params"][:, 0]
+        scale = input["smpl_params"][:, 0][:, None]
         smpl_pose = input["smpl_pose"]
         smpl_shape = input["smpl_shape"]
         smpl_trans = input["smpl_trans"]
@@ -127,8 +127,10 @@ def forward(self, input):
         if self.training:
             if input["current_epoch"] < 20 or input["current_epoch"] % 20 == 0:
                 cond = {"smpl": smpl_pose[:, 3:] * 0.0}
-        # ray_dirs, cam_loc = utils.get_camera_params(uv, pose, intrinsics)
-        ray_dirs, cam_loc = utils.get_camera_params_equirect(uv, camera_pos=camera_pos, camera_rotate=camera_rotate)
+
+        ray_dirs, cam_loc = utils.get_camera_params_equirect(
+            uv, camera_pos=camera_pos, camera_rotate=camera_rotate
+        )
         batch_size, num_pixels, _ = ray_dirs.shape
 
         cam_loc = cam_loc.unsqueeze(1).repeat(1, num_pixels, 1).reshape(-1, 3)
@@ -272,7 +274,7 @@ def forward(self, input):
             "points": points,
             "rgb_values": rgb_values,
             "normal_values": normal_values,
-            "index_outside": input["index_outside"],
+            # "index_outside": input["index_outside"],
             "index_off_surface": index_off_surface,
             "index_in_surface": index_in_surface,
             "acc_map": torch.sum(weights, -1),
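
Note: the scale change is a broadcasting fix: input["smpl_params"][:, 0] has shape (batch,), while [:, 0][:, None] yields (batch, 1), so the scale multiplies per-batch 3-vectors cleanly. A minimal illustration:

    import torch

    smpl_params = torch.randn(2, 86)       # (batch, 86); column 0 holds the scale
    trans = torch.randn(2, 3)              # one 3-vector per batch element

    scale_1d = smpl_params[:, 0]           # shape (2,)  -> (2,) * (2, 3) raises an error
    scale_2d = smpl_params[:, 0][:, None]  # shape (2, 1) -> broadcasts over xyz

    print((scale_2d * trans).shape)        # torch.Size([2, 3])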
38 changes: 28 additions & 10 deletions code/lib/utils/utils.py
@@ -3,6 +3,7 @@
 import torch
 from torch.nn import functional as F
 import math
+from pytorch3d.transforms import euler_angles_to_matrix
 
 
 def split_input(model_input, total_pixels, n_pixels=10000):
@@ -231,19 +232,32 @@ def get_camera_params_equirect(uv, camera_pos, camera_rotate):
     camera_yaw = camera_rotate[..., 1]
     camera_roll = camera_rotate[..., 2]
 
-    pitch = -camera_pitch
-    yaw = -camera_yaw
-    roll = -camera_roll
+    pitch = camera_pitch.unsqueeze(-1)
+    yaw = camera_yaw.unsqueeze(-1)
+    roll = camera_roll.unsqueeze(-1)
 
-    direc_cam = direc_cam.view(
-        -1, 3
-    )  # (batch, num_samples, 3) --> (batch * num_samples, 3)
+    # direc_cam = direc_cam.view(
+    #     -1, 3
+    # )  # (batch, num_samples, 3) --> (batch * num_samples, 3)
+    direc_cam = direc_cam.unsqueeze(-1)
 
-    direc_world = rotation_roll(direc_cam, roll)
-    direc_world = rotation_pitch(direc_world, pitch)
-    direc_world = rotation_yaw(direc_world, yaw)
+    R_world_to_camera = euler_angles_to_matrix(
+        torch.cat([roll, pitch, yaw], dim=-1), convention="ZXY"
+    )
+    # R_world_to_camera = R.from_euler(
+    #     "zxy", (roll, pitch, yaw), degrees=False
+    # ).as_matrix()
+    R_world_to_camera = R_world_to_camera.unsqueeze(1).repeat(1, num_samples, 1, 1)
+    direc_world = torch.einsum("bnij,bnjk->bnik", R_world_to_camera, direc_cam).squeeze(
+        -1
+    )
+    # direc_world = torch.bmm(R_world_to_camera, direc_cam.unsqueeze(-1)).squeeze(-1)
+
+    # direc_world = rotation_roll(direc_cam, roll)
+    # direc_world = rotation_pitch(direc_world, pitch)
+    # direc_world = rotation_yaw(direc_world, yaw)
 
-    direc_world = direc_world.view(batch_size, num_samples, 3)
+    # direc_world = direc_world.view(batch_size, num_samples, 3)
 
     return direc_world, camera_pos
@@ -404,6 +418,10 @@ def weighted_sampling(data, img_size, num_sample, bbox_ratio=0.9):
     bbox_max = where.max(axis=1)
 
     num_sample_bbox = int(num_sample * bbox_ratio)
+    # samples_bbox_indices = np.random.choice(
+    #     list(range(where.shape[1])), size=num_sample_bbox, replace=False
+    # )
+    # samples_bbox = where[:, samples_bbox_indices].transpose()
     samples_bbox = np.random.rand(num_sample_bbox, 2)
     samples_bbox = samples_bbox * (bbox_max - bbox_min) + bbox_min
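
Note: the three hand-rolled rotation_roll/pitch/yaw calls are replaced by one rotation matrix per batch element (euler_angles_to_matrix from pytorch3d, "ZXY" convention, as in the diff), broadcast over all ray directions and applied with a single einsum. A self-contained sketch of the same batched pattern in plain torch, using a simple yaw rotation as a stand-in matrix (the toy shapes are assumptions):

    import torch

    batch, num_samples = 2, 4
    direc_cam = torch.randn(batch, num_samples, 3, 1)  # one column vector per ray

    # One rotation matrix per batch element; here a yaw about the y-axis
    # stands in for euler_angles_to_matrix(torch.cat([roll, pitch, yaw], -1), "ZXY").
    yaw = torch.randn(batch)
    cos, sin = yaw.cos(), yaw.sin()
    zero, one = torch.zeros_like(yaw), torch.ones_like(yaw)
    R = torch.stack(
        [cos, zero, sin, zero, one, zero, -sin, zero, cos], dim=-1
    ).view(batch, 3, 3)

    # Broadcast the per-batch matrix over all samples, then rotate every ray at once.
    R = R.unsqueeze(1).repeat(1, num_samples, 1, 1)  # (batch, num_samples, 3, 3)
    direc_world = torch.einsum("bnij,bnjk->bnik", R, direc_cam).squeeze(-1)
    print(direc_world.shape)  # torch.Size([2, 4, 3])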
[Diffs for the remaining 2 changed files are not shown.]
