Skip to content

Commit

Permalink
Adding a fix and a test for calling rasterize with float64 vertices.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 648732380
  • Loading branch information
The diffren Authors committed Jul 2, 2024
1 parent e588c15 commit 2a57f58
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 16 deletions.
17 changes: 12 additions & 5 deletions diffren/jax/internal/rasterize.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,10 @@ def rasterize_triangles(vertices: jnp.ndarray,
None of the outputs of this function are differentiable.
Args:
vertices: float32 array of xyz positions with shape [vertex_count, d]. If
vertices: float array of xyz positions with shape [vertex_count, d]. If
projection_matrix is specified, d may be 3 or 4. If camera is None, d must
be 4 and the values are assumed to be xyzw homogenous coordinates.
triangles: int32 array with shape [triangle_count, 3].
triangles: int array with shape [triangle_count, 3].
camera: float array with shape [4, 4] containing a model-view-perspective
projection matrix (later, this will optionally be a camera with distortion
coefficients). May be None. If None, vertices are assumed to be 4D
Expand Down Expand Up @@ -94,9 +94,16 @@ def rasterize_triangles(vertices: jnp.ndarray,
vertices = jnp.matmul(
vertices, jnp.transpose(camera), precision=jax.lax.Precision.HIGHEST)

triangle_id, z_buffer, barycentrics = rasterize_triangles_xla.rasterize_triangles(
vertices, triangles, image_width, image_height, num_layers,
face_culling_mode)
triangle_id, z_buffer, barycentrics = (
rasterize_triangles_xla.rasterize_triangles(
vertices.astype(jnp.float32),
triangles.astype(jnp.int32),
image_width,
image_height,
num_layers,
face_culling_mode,
)
)

mask = (z_buffer != 1.0).astype(jnp.float32)

Expand Down
75 changes: 64 additions & 11 deletions diffren/jax/internal/rasterize_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
import jax.numpy as jnp
import numpy as np

jax.config.update('jax_enable_x64', True)


class RasterizeTest(chex.TestCase, parameterized.TestCase):

Expand All @@ -46,15 +48,63 @@ def setUp(self):

@chex.variants(with_jit=True, without_jit=True)
@parameterized.named_parameters(
('w constant', [1.0, 1.0, 1.0], 'Simple_Triangle.png', False, False),
('w constant diff barys', [1.0, 1.0, 1.0], 'Simple_Triangle.png', False,
True), ('w constant None camera', [1.0, 1.0, 1.0
], 'Simple_Triangle.png', True, False),
('w varying', [0.2, 0.5, 2.0], 'Perspective_Corrected_Triangle.png',
False, False), ('w varying diff barys', [0.2, 0.5, 2.0],
'Perspective_Corrected_Triangle.png', False, True))
def test_render_simple_triangle(self, w_vector, target_image_name,
use_none_camera, use_diff_barys):
(
'w constant',
[1.0, 1.0, 1.0],
'Simple_Triangle.png',
False,
False,
jnp.float32,
),
(
'w constant float64',
[1.0, 1.0, 1.0],
'Simple_Triangle.png',
False,
False,
jnp.float64,
),
(
'w constant diff barys',
[1.0, 1.0, 1.0],
'Simple_Triangle.png',
False,
True,
jnp.float32,
),
(
'w constant None camera',
[1.0, 1.0, 1.0],
'Simple_Triangle.png',
True,
False,
jnp.float32,
),
(
'w varying',
[0.2, 0.5, 2.0],
'Perspective_Corrected_Triangle.png',
False,
False,
jnp.float32,
),
(
'w varying diff barys',
[0.2, 0.5, 2.0],
'Perspective_Corrected_Triangle.png',
False,
True,
jnp.float32,
),
)
def test_render_simple_triangle(
self,
w_vector,
target_image_name,
use_none_camera,
use_diff_barys,
dtype,
):
"""Directly renders a rasterized triangle's barycentric coordinates.
Tests the wrapping code as well as the kernel.
Expand All @@ -65,12 +115,15 @@ def test_render_simple_triangle(self, w_vector, target_image_name,
use_none_camera: pass in None as the camera transform, or the identity
matrix
use_diff_barys: compute and test differentiable barycentric coordinates.
dtype: the dtype to use for the vertex coordinates.
"""
clip_coordinates = jnp.array(
[[-0.5, -0.5, 0.8, 1.0], [0.0, 0.5, 0.3, 1.0], [0.5, -0.5, 0.3, 1.0]],
dtype=jnp.float32)
dtype=dtype,
)
clip_coordinates = clip_coordinates * jnp.reshape(
jnp.array(w_vector, dtype=jnp.float32), [3, 1])
jnp.array(w_vector, dtype=dtype), [3, 1]
)
camera = None if use_none_camera else jnp.eye(4)

@self.variant
Expand Down

0 comments on commit 2a57f58

Please sign in to comment.