Commit a081d7d

update
LinZhuoChen committed Aug 10, 2024
1 parent 6c51c2c commit a081d7d
Showing 11 changed files with 138 additions and 279 deletions.
Binary file added docs/images/framework_new.png
4 changes: 2 additions & 2 deletions examples/gaussian_splatting/configs/colmap.yaml
@@ -20,7 +20,7 @@ trainer:
   camera_model:
     enable_training: False
   renderer:
-    name: "GaussianSplattingRender"
+    name: "MsplatRender"
     render_depth: True
     max_sh_degree: ${trainer.model.point_cloud.max_sh_degree}

@@ -92,4 +92,4 @@ trainer:

   exporter:
     # name: TSDFFusion
-    name: VideoExporter
+    name: BaseExporter
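
Aside on the two renames above: Pointrix resolves these `name:` strings through component registries (compare EXPORTER_REGISTRY in pointrix/exporter/base_exporter.py below), so swapping GaussianSplattingRender for MsplatRender, or VideoExporter for BaseExporter, is purely a config change. A minimal sketch of that lookup pattern, not part of this commit — this simplified Registry and the constructor-from-config step are illustrative assumptions, not Pointrix's exact API:

# Sketch of registry-based config resolution (simplified assumption).
class Registry:
    def __init__(self, name):
        self.name = name
        self._classes = {}

    def register(self):
        def wrap(cls):
            self._classes[cls.__name__] = cls
            return cls
        return wrap

    def get(self, name):
        return self._classes[name]


RENDERER_REGISTRY = Registry("RENDERER")


@RENDERER_REGISTRY.register()
class MsplatRender:
    def __init__(self, render_depth=False, max_sh_degree=3):
        self.render_depth = render_depth
        self.max_sh_degree = max_sh_degree


# The YAML's `name:` selects the class; the remaining keys become kwargs.
cfg = {"name": "MsplatRender", "render_depth": True, "max_sh_degree": 3}
renderer = RENDERER_REGISTRY.get(cfg.pop("name"))(**cfg)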
1 change: 0 additions & 1 deletion pointrix/engine/base_trainer.py
@@ -12,7 +12,6 @@
 from ..dataset import parse_data_set
 from ..utils.config import parse_structured
 from ..optimizer import parse_optimizer, parse_scheduler
-from ..exporter.novel_view import test_view_render, novel_view_render
 from ..exporter import parse_exporter
 from ..densification.gs import DensificationController
 from .default_datapipeline import BaseDataPipeline
5 changes: 1 addition & 4 deletions pointrix/engine/default_trainer.py
@@ -11,7 +11,6 @@
 from ..dataset import parse_data_set
 from ..utils.config import parse_structured
 from ..optimizer import parse_optimizer, parse_scheduler
-from ..exporter.novel_view import test_view_render, novel_view_render
 from ..exporter import parse_exporter
 from ..densification.gs import DensificationController
 from .default_datapipeline import BaseDataPipeline
@@ -43,8 +42,6 @@ def train_loop(self) -> None:
             self.call_hook("before_train_iter")
             # structure of batch {"frame_index": frame_index, "image": image, "depth": depth}
             batch = self.datapipeline.next_train(self.global_step)
-            # update the sh degree of renderer
-            self.model.renderer.update_sh_degree(iteration)
             # update learning rate
             self.schedulers.step(self.global_step, self.optimizer)
             # model forward step
@@ -83,7 +80,7 @@ def train_step(self, batch: List[dict]) -> None:
         #     "rotation": self.point_cloud.get_rotation,
         #     "shs": self.point_cloud.get_shs,
         # }
-        render_results = self.model(batch)
+        render_results = self.model(batch, iteration=self.global_step)
         # structure of render_results: {}
         # example of render_results = {
         #     "rgb": rgb,
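
Note on the two hunks above: the trainer no longer calls update_sh_degree itself; it threads iteration=self.global_step into the model's forward, so the model owns the spherical-harmonics schedule. A sketch of what that refactor implies, not part of this commit — the class names and the every-N-iterations schedule are assumptions (the schedule is the common 3DGS convention); only the forward(batch, iteration=...) shape comes from the diff:

# Sketch: the model advances its own SH degree from the iteration count.
class SketchRenderer:
    def __init__(self, max_sh_degree=3, interval=1000):
        self.max_sh_degree = max_sh_degree
        self.active_sh_degree = 0
        self.interval = interval

    def update_sh_degree(self, iteration):
        # Raise the active degree every `interval` iterations, capped at max.
        self.active_sh_degree = min(iteration // self.interval, self.max_sh_degree)


class SketchModel:
    def __init__(self, renderer):
        self.renderer = renderer

    def forward(self, batch, iteration=None, training=True):
        if training and iteration is not None:
            self.renderer.update_sh_degree(iteration)
        return {"rgb": None}  # rendering itself is omitted in this sketch


model = SketchModel(SketchRenderer())
model.forward(batch=[], iteration=2500)  # active_sh_degree is now 2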
72 changes: 35 additions & 37 deletions pointrix/exporter/base_exporter.py
@@ -16,6 +16,8 @@
 from ..engine.default_datapipeline import BaseDataPipeline
 from ..model.base_model import BaseModel
 
+from ..logger import ProgressLogger
+
 
 EXPORTER_REGISTRY = Registry("EXPORTER", modules=["pointrix.exporter"])
 EXPORTER_REGISTRY.__doc__ = ""
@@ -68,48 +70,44 @@ def forward(self, output_path):
         output_path : str
             The output path to save the images.
         """
-        l1_test = 0.0
-        psnr_test = 0.0
-        ssim_test = 0.0
-        lpips_test = 0.0
+        l1 = 0.0
+        psnr_metric = 0.0
+        ssim_metric = 0.0
+        lpips_metric = 0.0
         lpips_func = LPIPS()
         val_dataset = self.datapipeline.validation_dataset
        val_dataset_size = len(val_dataset)
-        progress_bar = tqdm(
-            range(0, val_dataset_size),
-            desc="Validation progress",
-            leave=False,
-        )
 
+        progress_logger = ProgressLogger(description='Extracting metrics', suffix='iters/s')
+        progress_logger.add_task(f'Metric', f'Extracting metrics', val_dataset_size)
         mkdir_p(os.path.join(output_path, 'test_view'))
 
-        for i in range(0, val_dataset_size):
-            batch = self.datapipeline.next_val(i)
-            render_results = self.model(batch, training=False)
-            image_name = os.path.basename(batch[0]['camera'].rgb_file_name)
-            gt_image = torch.clamp(batch[0]['image'].to("cuda").float(), 0.0, 1.0)
-            image = torch.clamp(
-                render_results['rgb'], 0.0, 1.0).squeeze()
-            visualize_feature = ['rgb']
+        with progress_logger.progress as progress:
+            for i in range(0, val_dataset_size):
+                batch = self.datapipeline.next_val(i)
+                render_results = self.model(batch, training=False)
+                image_name = os.path.basename(batch[0]['camera'].rgb_file_name)
+                gt = torch.clamp(batch[0]['image'].to("cuda").float(), 0.0, 1.0)
+                image = torch.clamp(
+                    render_results['rgb'], 0.0, 1.0).squeeze()
+                visualize_feature = ['rgb']
 
-            for feat_name in visualize_feature:
-                feat = render_results[feat_name]
-                visual_feat = eval(f"visualize_{feat_name}")(feat.squeeze())
-                if not os.path.exists(os.path.join(output_path, f'test_view_{feat_name}')):
-                    os.makedirs(os.path.join(
-                        output_path, f'test_view_{feat_name}'))
-                imageio.imwrite(os.path.join(
-                    output_path, f'test_view_{feat_name}', image_name), visual_feat)
+                for feat_name in visualize_feature:
+                    feat = render_results[feat_name]
+                    visual_feat = eval(f"visualize_{feat_name}")(feat.squeeze())
+                    if not os.path.exists(os.path.join(output_path, f'test_view_{feat_name}')):
+                        os.makedirs(os.path.join(
+                            output_path, f'test_view_{feat_name}'))
+                    imageio.imwrite(os.path.join(
+                        output_path, f'test_view_{feat_name}', image_name), visual_feat)
 
-            l1_test += l1_loss(image, gt_image).mean().double()
-            psnr_test += psnr(image, gt_image).mean().double()
-            ssim_test += ssim(image, gt_image).mean().double()
-            lpips_test += lpips_func(image, gt_image).mean().double()
-            progress_bar.update(1)
-        progress_bar.close()
-        l1_test /= val_dataset_size
-        psnr_test /= val_dataset_size
-        ssim_test /= val_dataset_size
-        lpips_test /= val_dataset_size
+                l1 += l1_loss(image, gt, return_mean=True).double()
+                psnr_metric += psnr(image, gt).mean().double()
+                ssim_metric += ssim(image, gt).mean().double()
+                lpips_metric += lpips_func(image, gt).mean().double()
+                progress_logger.update(f'Metric', step=1)
+        l1 /= val_dataset_size
+        psnr_metric /= val_dataset_size
+        ssim_metric /= val_dataset_size
+        lpips_metric /= val_dataset_size
         print(
-            f"Test results: L1 {l1_test:.5f} PSNR {psnr_test:.5f} SSIM {ssim_test:.5f} LPIPS (VGG) {lpips_test:.5f}")
+            f"Test results: L1 {l1:.5f} PSNR {psnr_metric:.5f} SSIM {ssim_metric:.5f} LPIPS (VGG) {lpips_metric:.5f}")
196 changes: 85 additions & 111 deletions pointrix/exporter/mesh_exporter.py
@@ -1,4 +1,5 @@
 import os
+import vdbfusion
 import random
 import numpy as np
 from pathlib import Path
@@ -8,109 +9,104 @@
 
 import torch
 import imageio
 
 from ..exporter.base_exporter import EXPORTER_REGISTRY, BaseExporter
+from ..logger import ProgressLogger
 
 
 @EXPORTER_REGISTRY.register()
 class TSDFFusion(BaseExporter):
     """
     The exporter class for the mesh export using tsdffusion.
     modified from https://github.com/maturk/dn-splatter/blob/main/dn_splatter/export_mesh.py
     """
 
     def forward(self, output_dir):
-        import vdbfusion
         output_dir = Path(output_dir)
         if not output_dir.exists():
             output_dir.mkdir(parents=True)
 
-        num_frames = len(self.datapipeline.iter_train_image_dataloader)  # type: ignore
-        samples_per_frame = (self.cfg.total_points + num_frames) // (num_frames)
-        TSDFvolume = vdbfusion.VDBVolume(
+        frame_count = len(
+            self.datapipeline.iter_train_image_dataloader)  # type: ignore
+        samples_per_frame = (self.cfg.total_points +
+                             frame_count) // frame_count
+        tsdf_volume = vdbfusion.VDBVolume(
             voxel_size=self.cfg.voxel_size, sdf_trunc=self.cfg.sdf_truc, space_carving=True
         )
-        points = []
-        colors = []
+        point_list = []
+        color_list = []
         self.model.point_cloud.cuda()
         with torch.no_grad():
-            for i, batch_data in enumerate(self.datapipeline.iter_train_image_dataloader):
-                print(i)
-                ## assume batch size == 1
-                data = batch_data[0]
-                camera = data["camera"]
-                render_results = self.model(batch_data)
-                # TODO
-                try:
-                    depth = render_results["depth"].squeeze()
-                except:
-                    raise ValueError('no depth in render_results,please set config --render_depth as True')
-                c2w = torch.eye(4, dtype=torch.float, device=depth.device)
-                c2w[:3, :4] = torch.linalg.inv(camera.extrinsic_matrix)[:3, :4]
-
-                # c2w = c2w @ torch.diag(
-                #     torch.tensor([1, -1, -1, 1], device=c2w.device, dtype=torch.float)
-                # )
-                c2w = c2w[:3, :4]
-                H, W = int(camera.image_height), int(camera.image_width)
-
-                indices = random.sample(range(H * W), samples_per_frame)
-
-                xyzs, rgbs = self.get_colored_points_from_depth(
-                    depths=depth,
-                    rgbs=render_results["rgb"].squeeze().permute(1, 2, 0),
-                    fx=camera.fx,
-                    fy=camera.fy,
-                    cx=camera.cx,  # type: ignore
-                    cy=camera.cy,  # type: ignore
-                    img_size=(W, H),
-                    c2w=c2w,
-                    mask=indices,
-                )
-                points.append(xyzs)
-                colors.append(rgbs)
-                TSDFvolume.integrate(
-                    xyzs.double().cpu().numpy(),
-                    extrinsic=c2w[:3, 3].double().cpu().numpy(),
-                )
-            vertices, faces = TSDFvolume.extract_triangle_mesh(min_weight=5)
-
-            mesh = o3d.geometry.TriangleMesh()
-            mesh.vertices = o3d.utility.Vector3dVector(vertices)
-            mesh.triangles = o3d.utility.Vector3iVector(faces)
-            mesh.compute_vertex_normals()
-            colors = torch.cat(colors, dim=0)
-            colors = colors.cpu().numpy()
-            mesh.vertex_colors = o3d.utility.Vector3dVector(colors)
-
-            # simplify mesh
-            if self.cfg.target_triangles is not None:
-                mesh = mesh.simplify_quadric_decimation(self.cfg.target_triangles)
-
-            o3d.io.write_triangle_mesh(
-                str(output_dir / "TSDFfusion_baseline_mesh.ply"),
-                mesh,
-            )
-            print(
-                f"Finished computing mesh: {str(output_dir / 'TSDFfusion.ply')}"
-            )
+            progress_logger = ProgressLogger(
+                description='Extracting mesh using TSDF', suffix='iters/s')
+            progress_logger.add_task(f'Mesh', f'Extracting mesh using TSDF', len(
+                self.datapipeline.iter_train_image_dataloader))
+            with progress_logger.progress as progress:
+                for i, batch in enumerate(self.datapipeline.iter_train_image_dataloader):
+                    # Assume batch size == 1
+                    data = batch[0]
+                    camera_info = data["camera"]
+                    render_output = self.model(batch)
+                    try:
+                        depth_map = render_output["depth"].squeeze()
+                    except:
+                        raise ValueError(
+                            'No depth in render_output, please set config trainer.model.renderer.render_depth as True')
+
+                    camera_to_world = torch.linalg.inv(
+                        camera_info.extrinsic_matrix)[:3, :4]
+                    height, width = int(camera_info.image_height), int(
+                        camera_info.image_width)
+
+                    sampled_indices = random.sample(
+                        range(height * width), samples_per_frame)
+
+                    points, colors = self.get_colored_points_from_depth(
+                        depths=depth_map,
+                        rgbs=render_output["rgb"].squeeze().permute(1, 2, 0),
+                        fx=camera_info.fx,
+                        fy=camera_info.fy,
+                        cx=camera_info.cx,
+                        cy=camera_info.cy,
+                        img_size=(width, height),
+                        c2w=camera_to_world,
+                        mask=sampled_indices,
+                    )
+                    point_list.append(points)
+                    color_list.append(colors)
+                    tsdf_volume.integrate(
+                        points.double().cpu().numpy(),
+                        extrinsic=camera_to_world[:3, 3].double().cpu().numpy(),
+                    )
+                    progress_logger.update(f'Mesh', step=1)
+            vertices, faces = tsdf_volume.extract_triangle_mesh(
+                min_weight=5)
+
+            mesh = o3d.geometry.TriangleMesh()
+            mesh.vertices = o3d.utility.Vector3dVector(vertices)
+            mesh.triangles = o3d.utility.Vector3iVector(faces)
+            mesh.compute_vertex_normals()
+            colors = torch.cat(color_list, dim=0)
+            colors = colors.cpu().numpy()
+            mesh.vertex_colors = o3d.utility.Vector3dVector(colors)
+
+            # Simplify mesh
+            if self.cfg.target_triangles is not None:
+                mesh = mesh.simplify_quadric_decimation(
+                    self.cfg.target_triangles)
+
+            o3d.io.write_triangle_mesh(
+                str(output_dir / "TSDFfusion_baseline_mesh.ply"),
+                mesh,
+            )
+            print(
+                f"Finished computing mesh: {str(output_dir / 'TSDFfusion.ply')}"
+            )
 
         mesh_clean = self.post_process_mesh(mesh, cluster_to_keep=1)
         o3d.io.write_triangle_mesh(
             str(output_dir / "TSDFfusion_baseline_mesh_clean.ply"),
             mesh_clean,
         )
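
Stripped of sampling and logging, the forward above follows the standard vdbfusion flow: integrate world-space points view by view, passing the camera origin as extrinsic, then extract the mesh and hand it to Open3D. A condensed, self-contained sketch, not part of this commit — random points stand in for back-projected depth, and the voxel size, truncation, and counts are arbitrary values rather than this repo's config:

import numpy as np
import open3d as o3d
import vdbfusion

# One integrate() per training view in the real exporter.
volume = vdbfusion.VDBVolume(voxel_size=0.02, sdf_trunc=0.08, space_carving=True)
rng = np.random.default_rng(0)
for _ in range(10):
    camera_origin = np.zeros(3)                      # would be c2w[:3, 3] per view
    points = rng.uniform(-1.0, 1.0, size=(4096, 3))  # stand-in for depth points
    volume.integrate(points, extrinsic=camera_origin)

vertices, faces = volume.extract_triangle_mesh(min_weight=5)
mesh = o3d.geometry.TriangleMesh()
mesh.vertices = o3d.utility.Vector3dVector(np.asarray(vertices))
mesh.triangles = o3d.utility.Vector3iVector(np.asarray(faces))
mesh.compute_vertex_normals()
o3d.io.write_triangle_mesh("TSDFfusion_sketch.ply", mesh)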


-    def get_colored_points_from_depth(
-        self,
-        depths: Tensor,
-        rgbs: Tensor,
-        c2w: Tensor,
-        fx: float,
-        fy: float,
-        cx: int,
-        cy: int,
-        img_size: tuple,
-        mask: Optional[Tensor] = None,
-    ) -> Tuple[Tensor, Tensor]:
+    def get_colored_points_from_depth(self, depths, rgbs, c2w, fx, fy, cx, cy, img_size,
+                                      mask: Optional[Tensor] = None):
         """Return colored pointclouds from depth and rgb frame and c2w. Optional masking.
         Returns:
@@ -136,30 +132,7 @@ def get_colored_points_from_depth(
         colors = rgbs.view(-1, 3)
         points = points
         return (points, colors)

-    def post_process_mesh(self, mesh, cluster_to_keep=1000):
-        """
-        Post-process a mesh to filter out floaters and disconnected parts
-        """
-        import copy
-        print("post processing the mesh to have {} clusterscluster_to_kep".format(cluster_to_keep))
-        mesh_0 = copy.deepcopy(mesh)
-        with o3d.utility.VerbosityContextManager(o3d.utility.VerbosityLevel.Debug) as cm:
-            triangle_clusters, cluster_n_triangles, cluster_area = (mesh_0.cluster_connected_triangles())
-
-        triangle_clusters = np.asarray(triangle_clusters)
-        cluster_n_triangles = np.asarray(cluster_n_triangles)
-        cluster_area = np.asarray(cluster_area)
-        n_cluster = np.sort(cluster_n_triangles.copy())[-cluster_to_keep]
-        n_cluster = max(n_cluster, 50)  # filter meshes smaller than 50
-        triangles_to_remove = cluster_n_triangles[triangle_clusters] < n_cluster
-        mesh_0.remove_triangles_by_mask(triangles_to_remove)
-        mesh_0.remove_unreferenced_vertices()
-        mesh_0.remove_degenerate_triangles()
-        print("num vertices raw {}".format(len(mesh.vertices)))
-        print("num vertices post {}".format(len(mesh_0.vertices)))
-        return mesh_0


     def get_means3d_backproj(
         self,
         depths: Tensor,
@@ -212,9 +185,10 @@ def get_means3d_backproj(
             c2w = torch.eye((means3d.shape[0], 4, 4), device=device)
 
         # to world coords
-        means3d = means3d @ torch.linalg.inv(c2w[..., :3, :3]) + c2w[..., :3, 3]
+        means3d = means3d @ torch.linalg.inv(
+            c2w[..., :3, :3]) + c2w[..., :3, 3]
         return means3d, image_coords

     def get_camera_coords(self, img_size: tuple, pixel_offset: float = 0.5) -> Tensor:
         """Generates camera pixel coordinates [W,H]
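
For reference, both depth-to-points helpers in this file implement standard pinhole back-projection: pixel (u, v) with depth d maps to ((u - cx) d / fx, (v - cy) d / fy, d) in camera coordinates, then into the world frame via c2w; since the rotation block is orthonormal, the torch.linalg.inv above is just a transpose. A self-contained sketch under those assumptions, not part of this commit:

import torch

def backproject_depth(depth, fx, fy, cx, cy, c2w):
    """Pinhole back-projection: (H, W) depth map -> (H*W, 3) world points.
    c2w is a (3, 4) camera-to-world matrix, as in the exporter above."""
    H, W = depth.shape
    v, u = torch.meshgrid(
        torch.arange(H, dtype=depth.dtype),
        torch.arange(W, dtype=depth.dtype),
        indexing="ij",
    )
    u, v = u + 0.5, v + 0.5  # pixel centers, matching pixel_offset=0.5 above
    x = (u - cx) * depth / fx
    y = (v - cy) * depth / fy
    cam_points = torch.stack([x, y, depth], dim=-1).reshape(-1, 3)
    # R.T equals R^-1 for an orthonormal rotation; then add the camera origin.
    return cam_points @ c2w[:3, :3].T + c2w[:3, 3]


depth = torch.full((4, 6), 2.0)
pts = backproject_depth(depth, fx=5.0, fy=5.0, cx=3.0, cy=2.0, c2w=torch.eye(3, 4))
print(pts.shape)  # torch.Size([24, 3])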