diff --git a/b3d/model.py b/b3d/model.py
index 54a888aa..4d0d00b3 100644
--- a/b3d/model.py
+++ b/b3d/model.py
@@ -200,4 +200,4 @@ def rerun_visualize_trace_t(trace, t, modes=["rgb", "depth", "inliers"]):
             pose.apply(vertices),
             colors=(attributes * 255).astype(jnp.uint8),
         ),
-    )
\ No newline at end of file
+    )
diff --git a/b3d/pose.py b/b3d/pose.py
index c736cf89..1388bc88 100644
--- a/b3d/pose.py
+++ b/b3d/pose.py
@@ -90,6 +90,53 @@ def camera_from_position_and_target(
     rotation_matrix = jnp.hstack([x.reshape(-1, 1), y.reshape(-1, 1), z.reshape(-1, 1)])
     return Pose(position, Rot.from_matrix(rotation_matrix).as_quat())
 
+def rotation_from_axis_angle(axis, angle):
+    """Creates a rotation matrix from an axis and angle.
+
+    Args:
+        axis (jnp.ndarray): The axis vector. Shape (3,)
+        angle (float): The angle in radians.
+    Returns:
+        jnp.ndarray: The rotation matrix. Shape (3, 3)
+    """
+    sina = jnp.sin(angle)
+    cosa = jnp.cos(angle)
+    direction = axis / jnp.linalg.norm(axis)
+    # Rodrigues' formula: R = cos(a) I + (1 - cos(a)) d d^T + sin(a) [d]_x
+    R = jnp.diag(jnp.array([cosa, cosa, cosa]))
+    R = R + jnp.outer(direction, direction) * (1.0 - cosa)
+    direction = direction * sina
+    R = R + jnp.array(
+        [
+            [0.0, -direction[2], direction[1]],
+            [direction[2], 0.0, -direction[0]],
+            [-direction[1], direction[0], 0.0],
+        ]
+    )
+    return R
+
+def from_rot(rotation):
+    """Creates a pose (with zero translation) from a rotation matrix.
+
+    Args:
+        rotation (jnp.ndarray): The rotation matrix. Shape (3, 3)
+    Returns:
+        Pose object
+    """
+    return Pose.from_matrix(jnp.vstack(
+        [jnp.hstack([rotation, jnp.zeros((3, 1))]), jnp.array([0.0, 0.0, 0.0, 1.0])]
+    ))
+
+def from_axis_angle(axis, angle):
+    """Creates a pose (with zero translation) from an axis and angle.
+
+    Args:
+        axis (jnp.ndarray): The axis vector. Shape (3,)
+        angle (float): The angle in radians.
+    Returns:
+        Pose object
+    """
+    return from_rot(rotation_from_axis_angle(axis, angle))
 
 @register_pytree_node_class
 class Pose:
diff --git a/b3d/renderer.py b/b3d/renderer.py
index d60eab02..d1a2013a 100644
--- a/b3d/renderer.py
+++ b/b3d/renderer.py
@@ -276,6 +276,95 @@ def render_attribute(self, pose, vertices, faces, ranges, attributes):
 
         return image[0], zs[0]
 
+    def render_attribute_normal_many(self, poses, vertices, faces, ranges, attributes):
+        """
+        Render many scenes to images by rasterizing and then interpolating attributes.
+
+        Parameters:
+            poses: float array, shape (num_scenes, num_objects, 4, 4)
+                Object pose matrices.
+            vertices: float array, shape (num_vertices, 3)
+                Vertex position matrix.
+            faces: int array, shape (num_triangles, 3)
+                Triangle faces matrix. The integers correspond to rows in the vertices matrix.
+            ranges: int array, shape (num_objects, 2)
+                Ranges matrix whose 2 elements per object specify the start index and count into the faces array.
+            attributes: float array, shape (num_vertices, num_attributes)
+                Attributes corresponding to the vertices.
+
+        Outputs:
+            image: float array, shape (num_scenes, height, width, num_attributes)
+                At each pixel the value is the barycentric interpolation of the attributes corresponding to the
+                3 vertices of the triangle with which the pixel's ray intersected. If the pixel's ray does not intersect
+                any triangle the value at that pixel will be 0s.
+            zs: float array, shape (num_scenes, height, width)
+                Depth of the intersection point. Zero if the pixel's ray does not hit a triangle.
+            norm_im: approximate surface normal image, shape (num_scenes, height, width, 3)
+        """
+        uvs, object_ids, triangle_ids, zs = self.rasterize_many(
+            poses, vertices, faces, ranges
+        )
+        mask = object_ids > 0
+
+        interpolated_values = self.interpolate_many(
+            attributes, uvs, triangle_ids, faces
+        )
+        image = interpolated_values * mask[..., None]
+
+        def apply_pose(pose, points):
+            return pose.apply(points)
+
+        pose_apply_map = jax.vmap(apply_pose, (0, None))
+        new_vertices = pose_apply_map(poses, vertices[faces])
+
+        def normal_vec(x, y, z):
+            vec = jnp.cross(y - x, z - x)
+            norm_vec = vec / jnp.linalg.norm(vec)
+            return norm_vec
+
+        normal_vec_vmap = jax.vmap(jax.vmap(normal_vec, (0, 0, 0)))
+        nvecs = normal_vec_vmap(new_vertices[..., 0, :], new_vertices[..., 1, :], new_vertices[..., 2, :])
+        norm_vecs = jnp.concatenate((jnp.zeros((len(nvecs), 1, 3)), nvecs), axis=1)
+
+        def indexer(transformed_normals, triangle_ids):
+            return transformed_normals[triangle_ids]
+
+        index_map = jax.vmap(indexer, (0, 0))
+        norm_im = index_map(norm_vecs, triangle_ids)
+
+        return image, zs, norm_im
+
+    def render_attribute_normal(self, pose, vertices, faces, ranges, attributes):
+        """
+        Render a single scene to an image by rasterizing and then interpolating attributes.
+
+        Parameters:
+            pose: float array, shape (num_objects, 4, 4)
+                Object pose matrices.
+            vertices: float array, shape (num_vertices, 3)
+                Vertex position matrix.
+            faces: int array, shape (num_triangles, 3)
+                Triangle faces matrix. The integers correspond to rows in the vertices matrix.
+            ranges: int array, shape (num_objects, 2)
+                Ranges matrix whose 2 elements per object specify the start index and count into the faces array.
+            attributes: float array, shape (num_vertices, num_attributes)
+                Attributes corresponding to the vertices.
+
+        Outputs:
+            image: float array, shape (height, width, num_attributes)
+                At each pixel the value is the barycentric interpolation of the attributes corresponding to the
+                3 vertices of the triangle with which the pixel's ray intersected. If the pixel's ray does not intersect
+                any triangle the value at that pixel will be 0s.
+            zs: float array, shape (height, width)
+                Depth of the intersection point. Zero if the pixel's ray does not hit a triangle.
+            norm_im: approximate surface normal image, shape (height, width, 3)
+        """
+        image, zs, norm_im = self.render_attribute_normal_many(
+            pose[None, ...], vertices, faces, ranges, attributes
+        )
+        return image[0], zs[0], norm_im[0]
+
+
 # XLA array layout in memory
 def default_layouts(*shapes):
     return [range(len(shape) - 1, -1, -1) for shape in shapes]
diff --git a/b3d/utils.py b/b3d/utils.py
index 4bc2c950..9200fa6b 100644
--- a/b3d/utils.py
+++ b/b3d/utils.py
@@ -395,6 +395,22 @@ def update_choices_get_score(trace, key, addr_const, *values):
     enumerate_choices_get_scores, static_argnums=(2,)
 )
 
+def unproject_depth(depth, renderer):
+    """Unprojects a depth image into a point cloud.
+
+    Args:
+        depth (jnp.ndarray): The depth image. Shape (H, W)
+        renderer: The renderer, whose camera intrinsics (fx, fy, cx, cy, near, far) are used.
+    Returns:
+        jnp.ndarray: The point cloud. Shape (H, W, 3)
+    """
+    mask = (depth < renderer.far) * (depth > renderer.near)
+    depth = depth * mask + renderer.far * (1.0 - mask)
+    y, x = jnp.mgrid[: depth.shape[0], : depth.shape[1]]
+    x = (x - renderer.cx) / renderer.fx
+    y = (y - renderer.cy) / renderer.fy
+    point_cloud_image = jnp.stack([x, y, jnp.ones_like(x)], axis=-1) * depth[:, :, None]
+    return point_cloud_image
 
 def nn_background_segmentation(images):
     import torch
diff --git a/test/test_likelihood_invariances.py b/test/test_likelihood_invariances.py
index 3a3b2933..79f483b3 100644
--- a/test/test_likelihood_invariances.py
+++ b/test/test_likelihood_invariances.py
@@ -169,3 +169,89 @@ def test_distance_to_camera_invarance(renderer):
     assert jnp.isclose(near_score, far_score, rtol=0.03)
 
 
+def test_patch_orientation_invariance(renderer):
+
+    object_library = b3d.MeshLibrary.make_empty_library()
+    occluder = trimesh.creation.box(extents=jnp.array([0.0001, 0.1, 0.1]))
+    occluder_colors = jnp.tile(jnp.array([0.8, 0.8, 0.8])[None, ...], (occluder.vertices.shape[0], 1))
+    object_library = b3d.MeshLibrary.make_empty_library()
+    object_library.add_object(occluder.vertices, occluder.faces, attributes=occluder_colors)
+
+    image_width = 200
+    image_height = 200
+    fx = 200.0
+    fy = 200.0
+    cx = 100.0
+    cy = 100.0
+    near = 0.001
+    far = 16.0
+    renderer.set_intrinsics(image_width, image_height, fx, fy, cx, cy, near, far)
+
+    flat_pose = b3d.Pose.from_position_and_target(
+        jnp.array([0.3, 0.0, 0.0]), jnp.array([0.0, 0.0, 0.0]), jnp.array([0.0, 0.0, 1.0])
+    ).inv()
+
+    from b3d.pose import from_axis_angle
+
+    transform_vec = jax.vmap(from_axis_angle, (None, 0))
+    in_place_rots = transform_vec(jnp.array([0, 0, 1]), jnp.linspace(0, jnp.pi / 4, 10))
+    tilt_pose = flat_pose @ in_place_rots[5]
+
+    rgb_flat, depth_flat = renderer.render_attribute(
+        flat_pose[None, ...],
+        object_library.vertices,
+        object_library.faces,
+        object_library.ranges,
+        object_library.attributes,
+    )
+
+    rgb_tilt, depth_tilt = renderer.render_attribute(
+        tilt_pose[None, ...],
+        object_library.vertices,
+        object_library.faces,
+        object_library.ranges,
+        object_library.attributes,
+    )
+
+
+    color_error, depth_error = (50.0, 0.01)
+    inlier_score, outlier_prob = (4.0, 0.000001)
+    color_multiplier, depth_multiplier = (100.0, 1.0)
+    model_args = b3d.ModelArgs(
+        color_error,
+        depth_error,
+        inlier_score,
+        outlier_prob,
+        color_multiplier,
+        depth_multiplier,
+    )
+
+    from genjax.generative_functions.distributions import ExactDensity
+    import genjax
+
+
+    rr.log("img_flat", rr.Image(rgb_flat))
+    rr.log("img_tilt", rr.Image(rgb_tilt))
+
+
+
+    area_flat = ((depth_flat / fx) * (depth_flat / fy)).sum()
+    area_tilt = ((depth_tilt / fx) * (depth_tilt / fy)).sum()
+    print(area_flat, area_tilt)
+
+    flat_score = (
+        b3d.rgbd_sensor_model.logpdf(
+            (rgb_flat, depth_flat), rgb_flat, depth_flat, model_args, fx, fy, 0.0
+        )
+    )
+
+    tilt_score = (
+        b3d.rgbd_sensor_model.logpdf(
+            (rgb_tilt, depth_tilt), rgb_tilt, depth_tilt, model_args, fx, fy, 0.0
+        )
+    )
+    print(flat_score, tilt_score)
+    print(b3d.normalize_log_scores(jnp.array([flat_score, tilt_score])))
+
+    assert jnp.isclose(flat_score, tilt_score, rtol=0.05)
+
diff --git a/test/test_render_ycb_model.py b/test/test_render_ycb_model.py
index e1f043cf..4c4a50a1 100644
--- a/test/test_render_ycb_model.py
+++ b/test/test_render_ycb_model.py
@@ -2,7 +2,11 @@
 import jax.numpy as jnp
 import trimesh
 import b3d
+import rerun as rr
 
+PORT = 8812
+rr.init("real")
+rr.connect(addr=f"127.0.0.1:{PORT}")
 
 def test_renderer_full(renderer):
     mesh_path = os.path.join(
@@ -15,7 +19,7 @@ def test_renderer_full(renderer):
     object_library.add_trimesh(mesh)
 
     pose = b3d.Pose.from_position_and_target(
-        jnp.array([0.2, 0.2, 0.0]), jnp.array([0.0, 0.0, 0.0])
+        jnp.array([0.2, 0.2, 0.2]), jnp.array([0.0, 0.0, 0.0])
     ).inv()
 
     rgb, depth = renderer.render_attribute(
@@ -27,3 +31,33 @@ def test_renderer_full(renderer):
     )
     b3d.get_rgb_pil_image(rgb).save(b3d.get_root_path() / "assets/test_results/test_ycb.png")
     assert rgb.sum() > 0
+
+def test_renderer_normal_full(renderer):
+    mesh_path = os.path.join(
+        b3d.get_root_path(),
+        "assets/shared_data_bucket/ycb_video_models/models/003_cracker_box/textured_simple.obj",
+    )
+    mesh = trimesh.load(mesh_path)
+
+    object_library = b3d.MeshLibrary.make_empty_library()
+    object_library.add_trimesh(mesh)
+
+    pose = b3d.Pose.from_position_and_target(
+        jnp.array([0.2, 0.2, 0.2]), jnp.array([0.0, 0.0, 0.0])
+    ).inv()
+
+    rgb, depth, normal = renderer.render_attribute_normal(
+        pose[None, ...],
+        object_library.vertices,
+        object_library.faces,
+        jnp.array([[0, len(object_library.faces)]]),
+        object_library.attributes,
+    )
+
+    b3d.get_rgb_pil_image((normal + 1) / 2).save(b3d.get_root_path() / "assets/test_results/test_ycb_normal.png")
+
+    point_im = b3d.utils.unproject_depth(depth, renderer)
+    rr.log("pc", rr.Points3D(point_im.reshape(-1, 3), colors=rgb.reshape(-1, 3)))
+    rr.log("arrows", rr.Arrows3D(origins=point_im[::5, ::5, :].reshape(-1, 3), vectors=normal[::5, ::5, :].reshape(-1, 3) / 100))
+
+    assert jnp.abs(normal).sum() > 0
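
Usage sketch: the snippet below shows how the new helpers added in this patch compose end to end; it assumes a constructed renderer whose intrinsics are already set and a populated b3d.MeshLibrary named object_library, and the tilt angle and variable names are illustrative, not part of the patch.

import jax.numpy as jnp
import b3d
from b3d.pose import from_axis_angle

# Illustrative camera pose looking at the origin, then tilted about the z-axis.
base_pose = b3d.Pose.from_position_and_target(
    jnp.array([0.2, 0.2, 0.2]), jnp.array([0.0, 0.0, 0.0])
).inv()
tilted_pose = base_pose @ from_axis_angle(jnp.array([0.0, 0.0, 1.0]), jnp.pi / 8)

# Render color, depth, and per-pixel face normals in one pass.
rgb, depth, normal = renderer.render_attribute_normal(
    tilted_pose[None, ...],
    object_library.vertices,
    object_library.faces,
    object_library.ranges,
    object_library.attributes,
)

# Back-project the depth map into an (H, W, 3) point cloud in camera coordinates.
point_im = b3d.utils.unproject_depth(depth, renderer)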