Commit

Merge branch 'feature/apply_stereo' into feature/optim_test
andabi committed Sep 25, 2023
2 parents cee8c79 + afd023f commit 9445f43
Showing 11 changed files with 462 additions and 264 deletions.
2 changes: 1 addition & 1 deletion code/confs/dataset/video.yaml
@@ -20,7 +20,7 @@ valid:
image_id: 0
batch_size: 1
drop_last: False
shuffle: False
shuffle: True
worker: 8

num_sample : -1
6 changes: 3 additions & 3 deletions code/confs/model/model_w_bg.yaml
@@ -1,8 +1,8 @@
learning_rate : 5.0e-4
body_param_learning_rate : 1.0e-4
body_param_learning_rate : 1.0e-3
sched_milestones : [200,500]
sched_factor : 0.5
smpl_init: True
smpl_init: False
is_continue: False
use_body_parsing: False
with_bkgd: True
@@ -61,7 +61,7 @@ density:
params_init: {beta: 0.1}
beta_min: 0.0001
ray_sampler:
near: 0.5
near: 0.0
N_samples: 64
N_samples_eval: 128
N_samples_extra: 32
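
The config diff above raises body_param_learning_rate and keeps the milestone schedule. Keys like these are usually consumed as optimizer parameter groups plus a MultiStepLR schedule; a minimal sketch of that pattern, with toy stand-in parameters rather than this repo's actual model or training loop:

import torch
import torch.nn as nn

# Toy stand-ins; the real model exposes its own parameter groups.
net = nn.Linear(3, 3)
body_params = [nn.Parameter(torch.zeros(72))]

optimizer = torch.optim.Adam(
    [
        {"params": net.parameters(), "lr": 5.0e-4},  # learning_rate
        {"params": body_params, "lr": 1.0e-3},       # body_param_learning_rate
    ]
)
# sched_milestones / sched_factor: multiply both learning rates by 0.5
# at the configured milestones (200 and 500).
scheduler = torch.optim.lr_scheduler.MultiStepLR(
    optimizer, milestones=[200, 500], gamma=0.5
)
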
120 changes: 83 additions & 37 deletions code/lib/datasets/dataset.py
@@ -5,6 +5,7 @@
import numpy as np
import torch
from lib.utils import utils
from scipy.spatial.transform import Rotation as R

os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"

@@ -42,6 +43,25 @@ def flip_y_2d_points(points, y_range=[-1.0, 1.0]):
return new_points


def get_right_camera_pos(camera_poses, camera_rotates):
right_shift = np.array([0.06, 0, 0])

right_shifts = []
for camera_rotate in camera_rotates:
camera_pitch = camera_rotate[..., 0] # (1,1)
camera_yaw = camera_rotate[..., 1] # (1,1)
camera_roll = camera_rotate[..., 2] # (1,1)

R_world_to_camera = R.from_euler(
"zxy", (camera_roll, camera_pitch, camera_yaw), degrees=False
).as_matrix()

right_shift_world = np.matmul(R_world_to_camera, right_shift)
right_shifts.append(right_shift_world)
right_shifts = np.stack(right_shifts)
return camera_poses + right_shifts
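
The helper above derives a right-eye camera position for each frame: a fixed camera-frame offset of 0.06 m along +x (presumably the stereo baseline) is rotated into world space with the per-frame Euler angles and added to the left-eye position. A small standalone sketch of that transform, separate from the dataset class and using made-up angles:

import numpy as np
from scipy.spatial.transform import Rotation as R

left_pos = np.zeros(3)
baseline = np.array([0.06, 0.0, 0.0])
for yaw in (0.0, np.pi / 2):  # assumed yaw values, purely for illustration
    rot = R.from_euler("zxy", (0.0, 0.0, yaw), degrees=False).as_matrix()
    right_pos = left_pos + rot @ baseline
    print(np.round(right_pos, 3))
# With zero rotation the right camera sits 0.06 m along +x from the left one;
# with a nonzero yaw the same offset is rotated into the world frame first.
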


class Dataset(torch.utils.data.Dataset):
def __init__(self, metainfo, split):
root = os.path.join(".", metainfo.data_dir)
@@ -50,41 +70,65 @@ def __init__(self, metainfo, split):
self.start_frame = metainfo.start_frame
self.end_frame = metainfo.end_frame
self.skip_step = 1
self.training_indices = list(
range(metainfo.start_frame, metainfo.end_frame, self.skip_step)
)

# images
img_dir = os.path.join(root, "image")
self.img_paths = sorted(glob.glob(f"{img_dir}/*.png"))
img_left_dir = os.path.join(root, "image", "left")
img_right_dir = os.path.join(root, "image", "right")

img_left_paths = sorted(glob.glob(f"{img_left_dir}/*.png"))
img_right_paths = sorted(glob.glob(f"{img_right_dir}/*.png"))

assert len(img_left_paths) == len(img_right_paths)

self.img_paths = sorted(glob.glob(f"{img_left_dir}/*.png")) + sorted(
glob.glob(f"{img_right_dir}/*.png")
)

# only store the image paths to avoid OOM
self.img_paths = [self.img_paths[i] for i in self.training_indices]

self.img_size = tuple(metainfo.img_size)

self.n_images = len(self.img_paths)

# coarse projected SMPL masks, only for sampling
mask_dir = os.path.join(root, "mask")
self.mask_paths = sorted(glob.glob(f"{mask_dir}/*.png"))
self.mask_paths = [self.mask_paths[i] for i in self.training_indices]

self.shape = np.load(os.path.join(root, "mean_shape.npy"))
self.poses = np.load(os.path.join(root, "poses.npy"))[self.training_indices]
self.trans = np.load(os.path.join(root, "normalize_trans.npy"))[
self.training_indices
]
mask_left_dir = os.path.join(root, "mask", "left")
mask_right_dir = os.path.join(root, "mask", "right")
self.mask_paths = sorted(glob.glob(f"{mask_left_dir}/*.png")) + sorted(
glob.glob(f"{mask_right_dir}/*.png")
)

# self.shape = np.load(os.path.join(root, "mean_shape.npy"))
# self.poses = np.load(os.path.join(root, "poses.npy"))
trans = np.load(os.path.join(root, "normalize_trans.npy"))

# cameras
camera_poses = np.load(os.path.join(root, "camera_pos.npy"))
camera_rotates = np.load(os.path.join(root, "camera_rotate.npy"))

assert camera_poses.shape[0] == camera_rotates.shape[0] == len(img_left_paths)
self.indices = np.concatenate(
[np.arange(camera_poses.shape[0]), np.arange(camera_poses.shape[0])], axis=0
)
camera_poses_right = get_right_camera_pos(camera_poses, camera_rotates)

camera_poses = np.concatenate([camera_poses, camera_poses_right], axis=0)
camera_rotates = np.concatenate([camera_rotates, camera_rotates], axis=0)

max_distance_from_camera_to_artist = np.linalg.norm(
np.concatenate([trans, trans], axis=0).squeeze(1) - camera_poses, axis=-1
).max()

scene_bounding_sphere = 3.0
# self.scale = 1 / scale_mats[0][0, 0]
self.scale = 1
self.scale = 1 / (
max_distance_from_camera_to_artist * 1.1 / scene_bounding_sphere
)

self.camera_poses = torch.tensor(camera_poses, dtype=torch.float32)
self.camera_rotates = torch.tensor(camera_rotates, dtype=torch.float32)

# normalize
# self.trans = self.trans * self.scale
self.camera_poses = self.camera_poses * self.scale

# other properties
self.num_sample = split.num_sample
self.sampling_strategy = "weighted"
@@ -106,8 +150,8 @@ def __getitem__(self, idx):
mask = cv2.imread(self.mask_paths[idx])
# preprocess: BGR -> Gray -> Mask
mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)
dilate_kernel = np.ones((20, 20), np.uint8)
mask_for_sampling = cv2.dilate(mask, dilate_kernel)
# dilate_kernel = np.ones((20, 20), np.uint8)
# mask_for_sampling = cv2.dilate(mask, dilate_kernel)

mask = mask > 0
mask_for_sampling = mask > 0
@@ -120,12 +164,12 @@
uv = np.stack(np.meshgrid(x, y, indexing="xy"), axis=-1) # (h, w, 2)
uv = flip_y_2d_points(uv, y_range=[-0.5, 0.5])

smpl_params = torch.zeros([86]).float()
smpl_params[0] = torch.from_numpy(np.asarray(self.scale)).float()
# smpl_params = torch.zeros([86]).float()
# smpl_params[0] = torch.from_numpy(np.asarray(self.scale)).float()

smpl_params[1:4] = torch.from_numpy(self.trans[idx]).float()
smpl_params[4:76] = torch.from_numpy(self.poses[idx]).reshape(-1).float()
smpl_params[76:] = torch.from_numpy(self.shape).float()
# smpl_params[1:4] = torch.from_numpy(self.trans[idx]).float()
# smpl_params[4:76] = torch.from_numpy(self.poses[idx]).reshape(-1).float()
# smpl_params[76:] = torch.from_numpy(self.shape).float()

if self.num_sample > 0:
data = {
@@ -144,9 +188,10 @@
# "pose": self.pose_all[idx],
"camera_poses": self.camera_poses[idx],
"camera_rotates": self.camera_rotates[idx],
"smpl_params": smpl_params,
# "smpl_params": smpl_params,
# "index_outside": index_outside,
"idx": idx,
"idx": self.indices[idx],
"scale": self.scale,
}
images = {
"rgb": samples["rgb"].astype(np.float32),
@@ -160,8 +205,9 @@
# "pose": self.pose_all[idx],
"camera_poses": self.camera_poses[idx],
"camera_rotates": self.camera_rotates[idx],
"smpl_params": smpl_params,
"idx": idx,
# "smpl_params": smpl_params,
"idx": self.indices[idx],
"scale": self.scale,
}
images = {
"rgb": img.reshape(-1, 3).astype(np.float32),
@@ -180,8 +226,7 @@ def __len__(self):
return 1

def __getitem__(self, idx):
image_id = int(np.random.choice(len(self.dataset), 1))
self.data = self.dataset[image_id]
self.data = self.dataset[idx]
inputs, images = self.data

inputs = {
@@ -190,9 +235,9 @@ def __getitem__(self, idx):
"camera_rotates": inputs["camera_rotates"],
# "intrinsics": inputs["intrinsics"],
# "pose": inputs["pose"],
"smpl_params": inputs["smpl_params"],
"image_id": image_id,
# "smpl_params": inputs["smpl_params"],
"idx": idx,
"scale": inputs["scale"],
}
images = {
"rgb": images["rgb"],
@@ -232,10 +277,11 @@ def __getitem__(self, idx):
"uv": uv,
"camera_poses": inputs["camera_poses"],
"camera_rotates": inputs["camera_rotates"],
# "intrinsics": inputs["intrinsics"],
# "pose": inputs["pose"],
"smpl_params": inputs["smpl_params"],
"idx": inputs["idx"],
# "smpl_params": inputs["smpl_params"],
"idx": idx,
"scale": inputs["scale"],
}
images = {
"rgb": images["rgb"],
"mask": images["mask"],
}
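
Two other changes in this file are worth noting. First, left-view and right-view image paths are now concatenated, and self.indices maps both views of a frame back to the same frame index, so the "idx" returned by __getitem__ is shared between the stereo pair. Second, the scene scale is no longer fixed at 1: cameras are rescaled so the farthest camera-to-subject distance, padded by 10%, lands inside a bounding sphere of radius 3. A worked example with an assumed (not measured) farthest distance:

# Assumed farthest camera-to-subject distance of 5.5 in normalized units.
scene_bounding_sphere = 3.0
max_distance = 5.5
scale = 1 / (max_distance * 1.1 / scene_bounding_sphere)
print(scale)                 # ~0.496
print(max_distance * scale)  # ~2.73, i.e. inside the radius-3 sphere
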
92 changes: 59 additions & 33 deletions code/lib/model/deformer.py
@@ -3,41 +3,64 @@
from .smpl import SMPLServer
from pytorch3d import ops

class SMPLDeformer():
def __init__(self, max_dist=0.1, K=1, gender='female', betas=None):

class SMPLDeformer:
def __init__(self, smpl: SMPLServer, max_dist=0.1, K=1, betas=None):
super().__init__()

self.max_dist = max_dist
self.K = K
self.smpl = SMPLServer(gender=gender)
self.smpl = smpl
smpl_params_canoical = self.smpl.param_canonical.clone()
smpl_params_canoical[:, 76:] = torch.tensor(betas).float().to(self.smpl.param_canonical.device)
cano_scale, cano_transl, cano_thetas, cano_betas = torch.split(smpl_params_canoical, [1, 3, 72, 10], dim=1)
smpl_params_canoical[:, 76:] = (
torch.tensor(betas).float().to(self.smpl.param_canonical.device)
)
cano_scale, cano_transl, cano_thetas, cano_betas = torch.split(
smpl_params_canoical, [1, 3, 72, 10], dim=1
)
smpl_output = self.smpl(cano_scale, cano_transl, cano_thetas, cano_betas)
self.smpl_verts = smpl_output['smpl_verts']
self.smpl_weights = smpl_output['smpl_weights']
def forward(self, x, smpl_tfs, return_weights=True, inverse=False, smpl_verts=None):
if x.shape[0] == 0: return x
if smpl_verts is None:
weights, outlier_mask = self.query_skinning_weights_smpl_multi(x[None], smpl_verts=self.smpl_verts[0], smpl_weights=self.smpl_weights)
self.smpl_verts = smpl_output["smpl_verts"]
self.smpl_weights = smpl_output["smpl_weights"]

def forward(
self,
x,
smpl_tfs,
return_weights=True,
inverse=False,
smpl_verts=None,
smpl_weights=None,
):
if x.shape[0] == 0:
return x
if smpl_verts is None or smpl_weights is None:
weights, outlier_mask = self.query_skinning_weights_smpl_multi(
x[None], smpl_verts=self.smpl_verts[0], smpl_weights=self.smpl_weights
)
else:
weights, outlier_mask = self.query_skinning_weights_smpl_multi(x[None], smpl_verts=smpl_verts[0], smpl_weights=self.smpl_weights)
# TODO: check which smpl weights to use
weights, outlier_mask = self.query_skinning_weights_smpl_multi(
x[None], smpl_verts=smpl_verts[0], smpl_weights=smpl_weights
)
if return_weights:
return weights

x_transformed = skinning(x.unsqueeze(0), weights, smpl_tfs, inverse=inverse)

return x_transformed.squeeze(0), outlier_mask

def forward_skinning(self, xc, cond, smpl_tfs):
weights, _ = self.query_skinning_weights_smpl_multi(xc, smpl_verts=self.smpl_verts[0], smpl_weights=self.smpl_weights)
weights, _ = self.query_skinning_weights_smpl_multi(
xc, smpl_verts=self.smpl_verts[0], smpl_weights=self.smpl_weights
)
x_transformed = skinning(xc, weights, smpl_tfs, inverse=False)

return x_transformed

def query_skinning_weights_smpl_multi(self, pts, smpl_verts, smpl_weights):

distance_batch, index_batch, neighbor_points = ops.knn_points(pts, smpl_verts.unsqueeze(0),
K=self.K, return_nn=True)
distance_batch, index_batch, neighbor_points = ops.knn_points(
pts, smpl_verts.unsqueeze(0), K=self.K, return_nn=True
)
distance_batch = torch.clamp(distance_batch, max=4)
weights_conf = torch.exp(-distance_batch)
distance_batch = torch.sqrt(distance_batch)
@@ -49,27 +72,30 @@ def query_skinning_weights_smpl_multi(self, pts, smpl_verts, smpl_weights):
outlier_mask = (distance_batch[..., 0] > self.max_dist)[0]
return weights, outlier_mask
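
The query above looks up skinning weights by nearest-neighbour search against the canonical SMPL vertices and flags points farther than max_dist from the nearest vertex as outliers. A rough plain-torch sketch of the same idea with K=1 (the actual code uses pytorch3d's ops.knn_points and its own distance-based blending, so details differ):

import torch

def query_weights_knn1(pts, smpl_verts, smpl_weights, max_dist=0.1):
    # pts: (P, 3), smpl_verts: (V, 3), smpl_weights: (1, V, J), e.g. J = 24 joints
    d2 = torch.cdist(pts, smpl_verts) ** 2   # squared distances, (P, V)
    dist2, idx = d2.min(dim=-1)              # nearest SMPL vertex per point
    weights = smpl_weights[0][idx]           # copy its skinning weights, (P, J)
    outlier_mask = dist2.sqrt() > max_dist   # too far from the body surface
    return weights, outlier_mask
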

def query_weights(self, xc):
weights = self.forward(xc, None, return_weights=True, inverse=False)
def query_weights(self, xc, smpl_weights):
weights = self.forward(
xc, None, return_weights=True, inverse=False, smpl_weights=smpl_weights
)
return weights

def forward_skinning_normal(self, xc, normal, cond, tfs, inverse = False):
if normal.ndim == 2:
normal = normal.unsqueeze(0)
w = self.query_weights(xc[0], cond)
# def forward_skinning_normal(self, xc, normal, cond, tfs, inverse=False):
# if normal.ndim == 2:
# normal = normal.unsqueeze(0)
# w = self.query_weights(xc[0], cond)

p_h = F.pad(normal, (0, 1), value=0)
# p_h = F.pad(normal, (0, 1), value=0)

if inverse:
# p:num_point, n:num_bone, i,j: num_dim+1
tf_w = torch.einsum('bpn,bnij->bpij', w.double(), tfs.double())
tf_w_inverse = tf_w.inverse()
# if inverse:
# # p:num_point, n:num_bone, i,j: num_dim+1
# tf_w = torch.einsum("bpn,bnij->bpij", w.double(), tfs.double())
# p_h = torch.einsum("bpij,bpj->bpi", tf_w.inverse(), p_h.double()).float()
# else:
# p_h = torch.einsum(
# "bpn, bnij, bpj->bpi", w.double(), tfs.double(), p_h.double()
# ).float()

# return p_h[:, :, :3]

p_h = torch.einsum('bpij,bpj->bpi', tf_w_inverse, p_h.double()).float()
else:
p_h = torch.einsum('bpn, bnij, bpj->bpi', w.double(), tfs.double(), p_h.double()).float()

return p_h[:, :, :3]

def skinning(x, w, tfs, inverse=False):
"""Linear blend skinning
@@ -90,4 +116,4 @@ def skinning(x, w, tfs, inverse=False):
x_h = torch.einsum("bpij,bpj->bpi", w_tf_inverse, x_h)
else:
x_h = torch.einsum("bpn,bnij,bpj->bpi", w, tfs, x_h)
return x_h[:, :, :3]
return x_h[:, :, :3]
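
The skinning helper implements standard linear blend skinning: each point, in homogeneous coordinates, is transformed by the weight-blended bone transforms, or by the inverse of that blend when mapping deformed points back to canonical space. A toy check of the forward einsum, independent of the repo's data:

import torch
import torch.nn.functional as F

# 1 batch, 2 points, 2 bones. Bone 0 is the identity, bone 1 translates by +1 in x.
tfs = torch.eye(4).repeat(1, 2, 1, 1)           # (B, n_bones, 4, 4)
tfs[0, 1, 0, 3] = 1.0
w = torch.tensor([[[1.0, 0.0], [0.5, 0.5]]])    # (B, n_points, n_bones)
x = torch.zeros(1, 2, 3)                        # both points start at the origin
x_h = F.pad(x, (0, 1), value=1.0)               # homogeneous coordinates
x_def = torch.einsum("bpn,bnij,bpj->bpi", w, tfs, x_h)[:, :, :3]
print(x_def)  # point 0 stays at the origin; point 1 moves to x = 0.5
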
10 changes: 6 additions & 4 deletions code/lib/model/loss.py
@@ -80,10 +80,12 @@ def forward(self, model_outputs, ground_truth):
rgb_gt[nan_filter],
)
eikonal_loss = self.get_eikonal_loss(model_outputs["grad_theta"])
if model_outputs["epoch"] < 20 or model_outputs["epoch"] % 20 == 0:
bce_loss = self.get_bce_loss(model_outputs["acc_map"])
else:
bce_loss = self.get_bce_loss(model_outputs["acc_map"], mask=None)
# if model_outputs["epoch"] < 20 or model_outputs["epoch"] % 20 == 0:
# bce_loss = self.get_bce_loss(model_outputs["acc_map"])
# else:
bce_loss = self.get_bce_loss(
model_outputs["acc_map"], mask=ground_truth["mask"]
)
opacity_sparse_loss = self.get_opacity_sparse_loss(
model_outputs["acc_map"], model_outputs["index_off_surface"]
)
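
With this change the accumulated-opacity loss always receives the ground-truth mask, instead of switching behaviour based on the epoch. The body of get_bce_loss is not shown in this diff; one plausible form of such a masked opacity term, written here only as an illustration, is a binary cross-entropy between the per-ray opacity and the foreground mask:

import torch
import torch.nn.functional as F

def masked_bce(acc_map, mask):
    # Clamp to keep the log terms finite, then compare per-ray opacity
    # against the binary foreground mask.
    acc = acc_map.clamp(1e-4, 1.0 - 1e-4)
    return F.binary_cross_entropy(acc, mask.float())

# Toy usage with assumed per-ray shapes.
loss = masked_bce(torch.rand(512), torch.rand(512) > 0.5)
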