Skip to content

Commit

Permalink
renderer surface norm init
Browse files Browse the repository at this point in the history
add normal rerun viz
  • Loading branch information
esli999 committed May 31, 2024
1 parent 77bd39e commit f8d14d2
Show file tree
Hide file tree
Showing 6 changed files with 274 additions and 2 deletions.
2 changes: 1 addition & 1 deletion b3d/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,4 +200,4 @@ def rerun_visualize_trace_t(trace, t, modes=["rgb", "depth", "inliers"]):
pose.apply(vertices),
colors=(attributes * 255).astype(jnp.uint8),
),
)
)
47 changes: 47 additions & 0 deletions b3d/pose.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,53 @@ def camera_from_position_and_target(
rotation_matrix = jnp.hstack([x.reshape(-1, 1), y.reshape(-1, 1), z.reshape(-1, 1)])
return Pose(position, Rot.from_matrix(rotation_matrix).as_quat())

def rotation_from_axis_angle(axis, angle):
"""Creates a rotation matrix from an axis and angle.
Args:
axis (jnp.ndarray): The axis vector. Shape (3,)
angle (float): The angle in radians.
Returns:
jnp.ndarray: The rotation matrix. Shape (3, 3)
"""
sina = jnp.sin(angle)
cosa = jnp.cos(angle)
direction = axis / jnp.linalg.norm(axis)
# rotation matrix around unit vector
R = jnp.diag(jnp.array([cosa, cosa, cosa]))
R = R + jnp.outer(direction, direction) * (1.0 - cosa)
direction = direction * sina
R = R + jnp.array(
[
[0.0, -direction[2], direction[1]],
[direction[2], 0.0, -direction[0]],
[-direction[1], direction[0], 0.0],
]
)
return R

def from_rot(rotation):
"""Creates a pose matrix from a rotation matrix.
Args:
rotation (jnp.ndarray): The rotation matrix. Shape (3, 3)
Returns:
Pose object
"""
return Pose.from_matrix(jnp.vstack(
[jnp.hstack([rotation, jnp.zeros((3, 1))]), jnp.array([0.0, 0.0, 0.0, 1.0])]
))

def from_axis_angle(axis, angle):
"""Creates a pose matrix from an axis and angle.
Args:
axis (jnp.ndarray): The axis vector. Shape (3,)
angle (float): The angle in radians.
Returns:
Pose object
"""
return from_rot(rotation_from_axis_angle(axis, angle))

@register_pytree_node_class
class Pose:
Expand Down
89 changes: 89 additions & 0 deletions b3d/renderer.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,95 @@ def render_attribute(self, pose, vertices, faces, ranges, attributes):
return image[0], zs[0]


def render_attribute_normal_many(self, poses, vertices, faces, ranges, attributes):
"""
Render many scenes to an image by rasterizing and then interpolating attributes.
Parameters:
poses: float array, shape (num_scenes, num_objectsß, 4, 4)
Object pose matrix.
vertices: float array, shape (num_vertices, 3)
Vertex position matrix.
faces: int array, shape (num_triangles, 3)
Faces Triangle matrix. The integers ßcorrespond to rows in the vertices matrix.
ranges: int array, shape (num_objects, 2)
Ranges matrix with the 2 elements specify start indices and counts into faces.
attributes: float array, shape (num_vertices, num_attributes)
Attributes corresponding to the vertices
Outputs:
image: float array, shape (num_scenes, height, width, num_attributes)
At each pixel the value is the barycentric interpolation of the attributes corresponding to the
3 vertices of the triangle with which the pixel's ray intersected. If the pixel's ray does not intersect
any triangle the value at that pixel will be 0s.
zs: float array, shape (num_scenes, height, width)
Depth of the intersection point. Zero if the pixel ray doesn't collide a triangle.
norm_im: approximate surface normal image (num_scenes, height, width, 3)
"""
uvs, object_ids, triangle_ids, zs = self.rasterize_many(
poses, vertices, faces, ranges
)
mask = object_ids > 0

interpolated_values = self.interpolate_many(
attributes, uvs, triangle_ids, faces
)
image = interpolated_values * mask[..., None]

def apply_pose(pose, points):
return pose.apply(points)

pose_apply_map = jax.vmap(apply_pose, (0,None))
new_vertices = pose_apply_map(poses, vertices[faces])

def normal_vec(x,y,z):
vec = jnp.cross(y - x, z - x)
norm_vec = vec / jnp.linalg.norm(vec)
return norm_vec

normal_vec_vmap = jax.vmap(jax.vmap(normal_vec, (0,0,0)))
nvecs = normal_vec_vmap(new_vertices[...,0,:], new_vertices[...,1,:], new_vertices[...,2,:])
norm_vecs = jnp.concatenate((jnp.zeros((len(nvecs),1,3)), nvecs),axis=1)

def indexer(transformed_normals, triangle_ids):
return transformed_normals[triangle_ids]

index_map = jax.vmap(indexer, (0,0))
norm_im = index_map(norm_vecs, triangle_ids)

return image, zs, norm_im

def render_attribute_normal(self, pose, vertices, faces, ranges, attributes):
"""
Render a single scenes to an image by rasterizing and then interpolating attributes.
Parameters:
poses: float array, shape (num_objects, 4, 4)
Object pose matrix.
vertices: float array, shape (num_vertices, 3)
Vertex position matrix.
faces: int array, shape (num_triangles, 3)
Faces Triangle matrix. The integers correspond to rows in the vertices matrix.
ranges: int array, shape (num_objects, 2)
Ranges matrix with the 2 elements specify start indices and counts into faces.
attributes: float array, shape (num_vertices, num_attributes)
Attributes corresponding to the vertices
Outputs:
image: float array, shape (height, width, num_attributes)
At each pixel the value is the barycentric interpolation of the attributes corresponding to the
3 vertices of the triangle with which the pixel's ray intersected. If the pixel's ray does not intersect
any triangle the value at that pixel will be 0s.
zs: float array, shape (height, width)
Depth of the intersection point. Zero if the pixel ray doesn't collide a triangle.
norm_im: approximate surface normal image (height, width, 3)
"""
image, zs, norm_im = self.render_attribute_normal_many(
pose[None, ...], vertices, faces, ranges, attributes
)
return image[0], zs[0], norm_im[0]


# XLA array layout in memory
def default_layouts(*shapes):
return [range(len(shape) - 1, -1, -1) for shape in shapes]
Expand Down
16 changes: 16 additions & 0 deletions b3d/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,6 +395,22 @@ def update_choices_get_score(trace, key, addr_const, *values):
enumerate_choices_get_scores, static_argnums=(2,)
)

def unproject_depth(depth, renderer):
"""Unprojects a depth image into a point cloud.
Args:
depth (jnp.ndarray): The depth image. Shape (H, W)
intrinsics (b.camera.Intrinsics): The camera intrinsics.
Returns:
jnp.ndarray: The point cloud. Shape (H, W, 3)
"""
mask = (depth < renderer.far) * (depth > renderer.near)
depth = depth * mask + renderer.far * (1.0 - mask)
y, x = jnp.mgrid[: depth.shape[0], : depth.shape[1]]
x = (x - renderer.cx) / renderer.fx
y = (y - renderer.cy) / renderer.fy
point_cloud_image = jnp.stack([x, y, jnp.ones_like(x)], axis=-1) * depth[:, :, None]
return point_cloud_image

def nn_background_segmentation(images):
import torch
Expand Down
86 changes: 86 additions & 0 deletions test/test_likelihood_invariances.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,3 +169,89 @@ def test_distance_to_camera_invarance(renderer):

assert jnp.isclose(near_score, far_score, rtol=0.03)

def test_patch_orientation_invariance(renderer):

object_library = b3d.MeshLibrary.make_empty_library()
occluder = trimesh.creation.box(extents=jnp.array([0.0001, 0.1, 0.1]))
occluder_colors = jnp.tile(jnp.array([0.8, 0.8, 0.8])[None,...], (occluder.vertices.shape[0], 1))
object_library = b3d.MeshLibrary.make_empty_library()
object_library.add_object(occluder.vertices, occluder.faces, attributes=occluder_colors)

image_width = 200
image_height = 200
fx = 200.0
fy = 200.0
cx = 100.0
cy = 100.0
near = 0.001
far = 16.0
renderer.set_intrinsics(image_width, image_height, fx, fy, cx, cy, near, far)

flat_pose = b3d.Pose.from_position_and_target(
jnp.array([0.3, 0.0, 0.0]), jnp.array([0.0, 0.0, 0.0]), jnp.array([0.0, 0.0, 1.0])
).inv()

from b3d.pose import from_axis_angle

transform_vec = jax.vmap(from_axis_angle, (None, 0))
in_place_rots = transform_vec(jnp.array([0,0,1]), jnp.linspace(0, jnp.pi/4, 10))
tilt_pose = flat_pose @ in_place_rots[5]

rgb_flat, depth_flat = renderer.render_attribute(
flat_pose[None, ...],
object_library.vertices,
object_library.faces,
object_library.ranges,
object_library.attributes,
)

rgb_tilt, depth_tilt = renderer.render_attribute(
tilt_pose[None, ...],
object_library.vertices,
object_library.faces,
object_library.ranges,
object_library.attributes,
)


color_error, depth_error = (50.0, 0.01)
inlier_score, outlier_prob = (4.0, 0.000001)
color_multiplier, depth_multiplier = (100.0, 1.0)
model_args = b3d.ModelArgs(
color_error,
depth_error,
inlier_score,
outlier_prob,
color_multiplier,
depth_multiplier,
)

from genjax.generative_functions.distributions import ExactDensity
import genjax


rr.log("img_near", rr.Image(rgb_flat))
rr.log("img_far", rr.Image(rgb_tilt))



area_flat = ((depth_flat / fx) * (depth_flat / fy)).sum()
area_tilt = ((depth_tilt / fx) * (depth_tilt / fy)).sum()
print(area_flat, area_tilt)

flat_score = (
b3d.rgbd_sensor_model.logpdf(
(rgb_flat, depth_flat), rgb_flat, depth_flat, model_args, fx, fy, 0.0
)
)

tilt_score = (
b3d.rgbd_sensor_model.logpdf(
(rgb_tilt, depth_tilt), rgb_tilt, depth_tilt, model_args, fx, fy, 0.0
)
)
print(flat_score, tilt_score)
print(b3d.normalize_log_scores(jnp.array([flat_score, tilt_score])))

assert jnp.isclose(flat_score, tilt_score, rtol=0.05)

36 changes: 35 additions & 1 deletion test/test_render_ycb_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,11 @@
import jax.numpy as jnp
import trimesh
import b3d
import rerun as rr

PORT = 8812
rr.init("real")
rr.connect(addr=f"127.0.0.1:{PORT}")

def test_renderer_full(renderer):
mesh_path = os.path.join(
Expand All @@ -15,7 +19,7 @@ def test_renderer_full(renderer):
object_library.add_trimesh(mesh)

pose = b3d.Pose.from_position_and_target(
jnp.array([0.2, 0.2, 0.0]), jnp.array([0.0, 0.0, 0.0])
jnp.array([0.2, 0.2, 0.2]), jnp.array([0.0, 0.0, 0.0])
).inv()

rgb, depth = renderer.render_attribute(
Expand All @@ -27,3 +31,33 @@ def test_renderer_full(renderer):
)
b3d.get_rgb_pil_image(rgb).save(b3d.get_root_path() / "assets/test_results/test_ycb.png")
assert rgb.sum() > 0

def test_renderer_normal_full(renderer):
mesh_path = os.path.join(
b3d.get_root_path(),
"assets/shared_data_bucket/ycb_video_models/models/003_cracker_box/textured_simple.obj",
)
mesh = trimesh.load(mesh_path)

object_library = b3d.MeshLibrary.make_empty_library()
object_library.add_trimesh(mesh)

pose = b3d.Pose.from_position_and_target(
jnp.array([0.2, 0.2, 0.2]), jnp.array([0.0, 0.0, 0.0])
).inv()

rgb, depth, normal = renderer.render_attribute_normal(
pose[None, ...],
object_library.vertices,
object_library.faces,
jnp.array([[0, len(object_library.faces)]]),
object_library.attributes,
)

b3d.get_rgb_pil_image((normal+1)/2).save(b3d.get_root_path() / "assets/test_results/test_ycb_normal.png")

point_im = b3d.utils.unproject_depth(depth, renderer)
rr.log("pc", rr.Points3D(point_im.reshape(-1,3), colors=rgb.reshape(-1,3)))
rr.log("arrows", rr.Arrows3D(origins=point_im[::5,::5,:].reshape(-1,3), vectors=normal[::5,::5,:].reshape(-1,3)/100))

assert jnp.abs(normal).sum() > 0

0 comments on commit f8d14d2

Please sign in to comment.