Fix coordinate system conventions in renderer
Summary:
## Updates

- Defined the world and camera coordinates according to this figure. The world coordinates are defined as having +Y up, +X left and +Z in.

{F230888499}

- Removed all flipping from blending functions.
- Updated the rasterizer to return images with +Y up and +X left.
- Updated all the mesh rasterizer tests:
    - The expected values are now defined in terms of the default +Y up, +X left convention.
    - Added tests where the triangles in the meshes are non-symmetrical, so that it is clear which way +X and +Y point.
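
As a rough illustration of the rasterizer change, the sketch below mirrors the index arithmetic the naive kernel uses after this commit: a row-major pixel index is split into a row and a column, and both are reversed so that +Y points up and +X points left. It is plain Python rather than CUDA, and the function name `flipped_pixel_coords` is made up for this example, not taken from the codebase.

```python
from typing import Tuple


def flipped_pixel_coords(pix_idx: int, H: int, W: int) -> Tuple[int, int]:
    """Map a row-major pixel index to (yi, xi) with +Y up and +X left."""
    row = pix_idx // W  # row in the output tensor, counted from the top
    col = pix_idx % W   # column in the output tensor, counted from the left
    yi = H - 1 - row    # reverse rows so that larger yi is higher in the image
    xi = W - 1 - col    # reverse columns so that larger xi is further left
    return yi, xi


if __name__ == "__main__":
    H, W = 2, 3
    for pix_idx in range(H * W):
        print(pix_idx, flipped_pixel_coords(pix_idx, H, W))
```

The first output pixel (top-left of the tensor) maps to the largest `yi` and `xi`, which is what a +Y up, +X left image convention implies.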

## Questions:
- Should we have **scene settings** instead of raster settings?
    - To be more correct, we should be [z clipping in the rasterizer based on the far/near clipping planes](https://github.com/ShichenLiu/SoftRas/blob/master/soft_renderer/cuda/soft_rasterize_cuda_kernel.cu#L400). These values are also required in the blending functions, so should we make them scene-level parameters and have a scene settings tuple that is available to both the rasterizer and the shader?
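
To make the question concrete, here is one possible shape for such a tuple. This is purely a sketch: `SceneSettings`, its field names, its default values, and `within_clip_planes` are all hypothetical and do not exist in the codebase.

```python
from typing import NamedTuple


class SceneSettings(NamedTuple):
    """Hypothetical scene-level parameters shared by rasterizer and shader."""

    znear: float = 1.0    # placeholder default, not taken from the library
    zfar: float = 100.0   # placeholder default, not taken from the library


def within_clip_planes(z: float, scene: SceneSettings) -> bool:
    """Return True if a depth value lies between the near and far planes."""
    return scene.znear <= z <= scene.zfar


if __name__ == "__main__":
    scene = SceneSettings(zfar=10.0)
    print(within_clip_planes(5.0, scene))   # inside the clipping range
    print(within_clip_planes(50.0, scene))  # beyond the far plane
```

A single immutable tuple like this could be passed to both stages, so the rasterizer and the blending code would read the same near/far values.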

Reviewed By: gkioxari

Differential Revision: D20208604

fbshipit-source-id: 55787301b1bffa0afa9618f0a0886cc681da51f3
nikhilaravi authored and facebook-github-bot committed Mar 6, 2020
1 parent 767d68a commit 15c72be
Showing 27 changed files with 522 additions and 482 deletions.
Binary file modified docs/notes/assets/transformations_overview.png
Binary file added docs/notes/assets/world_camera_image.png
15 changes: 9 additions & 6 deletions docs/notes/renderer_getting_started.md
@@ -34,19 +34,22 @@ The differentiable renderer API is experimental and subject to change!.

### Coordinate transformation conventions

Rendering requires transformations between several different coordinate frames: world space, view/camera space, NDC space and screen space. At each step it is important to know where the camera is located, how the x,y,z axes are aligned and the possible range of values. The following figure outlines the conventions used in PyTorch3d.
Rendering requires transformations between several different coordinate frames: world space, view/camera space, NDC space and screen space. At each step it is important to know where the camera is located, how the +X, +Y, +Z axes are aligned and the possible range of values. The following figure outlines the conventions used in PyTorch3d.

<img src="assets/transformations_overview.png" width="1000">


For example, given a teapot mesh, the world coordinate frame, camera coordinate frame and image are shown in the figure below. Note that the world and camera coordinate frames have the +z direction pointing into the page.

<img src="assets/world_camera_image.png" width="1000">

---

**NOTE: PyTorch3d vs OpenGL**

While we tried to emulate several aspects of OpenGL, the NDC coordinate system in PyTorch3d is **right-handed** compared with a **left-handed** NDC coordinate system in OpenGL (the projection matrix switches the handedness).

In OpenGL, the camera at the origin is looking along `-z` axis in camera space, but it is looking along the `+z` axis in NDC space.
While we tried to emulate several aspects of OpenGL, there are differences in the coordinate frame conventions.
- The default world coordinate frame in PyTorch3D has +Z pointing into the screen, whereas in OpenGL +Z points out of the screen. Both are right-handed.
- The NDC coordinate system in PyTorch3d is **right-handed** compared with a **left-handed** NDC coordinate system in OpenGL (the projection matrix switches the handedness).
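
The handedness claims above can be checked with a cross product: a frame is right-handed when X × Y = Z. The sketch below is illustrative only. It expresses each frame's axes in a viewer-aligned reference frame (right, up, out of the screen), which is an assumption of this example, and the helper names are made up.

```python
def cross(a, b):
    """Cross product of two 3-vectors given as tuples."""
    return (
        a[1] * b[2] - a[2] * b[1],
        a[2] * b[0] - a[0] * b[2],
        a[0] * b[1] - a[1] * b[0],
    )


def is_right_handed(x_axis, y_axis, z_axis) -> bool:
    """A frame is right-handed when X cross Y equals Z."""
    return cross(x_axis, y_axis) == tuple(z_axis)


# Axes written in a viewer-aligned reference frame: (right, up, out of screen).
PYTORCH3D_WORLD = ((-1, 0, 0), (0, 1, 0), (0, 0, -1))  # +X left, +Y up, +Z in
OPENGL_WORLD = ((1, 0, 0), (0, 1, 0), (0, 0, 1))       # +X right, +Y up, +Z out
LEFT_HANDED = ((1, 0, 0), (0, 1, 0), (0, 0, -1))       # +X right, +Y up, +Z in

if __name__ == "__main__":
    print(is_right_handed(*PYTORCH3D_WORLD))  # True
    print(is_right_handed(*OPENGL_WORLD))     # True
    print(is_right_handed(*LEFT_HANDED))      # False
```

Flipping two axes at once (X and Z here) preserves handedness, which is why both world frames are right-handed even though +Z points in opposite directions.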

<img align="center" src="assets/opengl_coordframes.png" width="300">

@@ -60,7 +63,7 @@ A renderer in PyTorch3d is composed of a **rasterizer** and a **shader**. Create
from pytorch3d.renderer import (
OpenGLPerspectiveCameras, look_at_view_transform,
RasterizationSettings, BlendParams,
MeshRenderer, MeshRasterizer, PhongShader
MeshRenderer, MeshRasterizer, HardPhongShader
)
# Initialize an OpenGL perspective camera.
@@ -81,7 +84,7 @@ raster_settings = RasterizationSettings(
# PhongShader, passing in the device on which to initialize the default parameters
renderer = MeshRenderer(
rasterizer=MeshRasterizer(cameras=cameras, raster_settings=raster_settings),
shader=PhongShader(device=device, cameras=cameras)
shader=HardPhongShader(device=device, cameras=cameras)
)
```


Large diffs are not rendered by default.

145 changes: 47 additions & 98 deletions docs/tutorials/render_textured_meshes.ipynb

Large diffs are not rendered by default.

77 changes: 46 additions & 31 deletions pytorch3d/csrc/rasterize_meshes/rasterize_meshes.cu
@@ -189,12 +189,12 @@ __global__ void RasterizeMeshesNaiveCudaKernel(
const float* face_verts,
const int64_t* mesh_to_face_first_idx,
const int64_t* num_faces_per_mesh,
float blur_radius,
bool perspective_correct,
int N,
int H,
int W,
int K,
const float blur_radius,
const bool perspective_correct,
const int N,
const int H,
const int W,
const int K,
int64_t* face_idxs,
float* zbuf,
float* pix_dists,
@@ -207,8 +207,10 @@ __global__ void RasterizeMeshesNaiveCudaKernel(
// Convert linear index to 3D index
const int n = i / (H * W); // batch index.
const int pix_idx = i % (H * W);
const int yi = pix_idx / H;
const int xi = pix_idx % W;

// Determine ordering based on axis convention.
const int yi = H - 1 - pix_idx / W;
const int xi = W - 1 - pix_idx % W;

// screen coordinates to ndc coordinates of pixel.
const float xf = PixToNdc(xi, W);
@@ -254,7 +256,7 @@ __global__ void RasterizeMeshesNaiveCudaKernel(

// TODO: make sorting an option as only top k is needed, not sorted values.
BubbleSort(q, q_size);
int idx = n * H * W * K + yi * H * K + xi * K;
int idx = n * H * W * K + pix_idx * K;
for (int k = 0; k < q_size; ++k) {
face_idxs[idx + k] = q[k].idx;
zbuf[idx + k] = q[k].z;
@@ -274,7 +276,7 @@ RasterizeMeshesNaiveCuda(
const int image_size,
const float blur_radius,
const int num_closest,
bool perspective_correct) {
const bool perspective_correct) {
if (face_verts.ndimension() != 3 || face_verts.size(1) != 3 ||
face_verts.size(2) != 3) {
AT_ERROR("face_verts must have dimensions (num_faces, 3, 3)");
@@ -331,12 +333,12 @@ RasterizeMeshesNaiveCuda(
__global__ void RasterizeMeshesBackwardCudaKernel(
const float* face_verts, // (F, 3, 3)
const int64_t* pix_to_face, // (N, H, W, K)
bool perspective_correct,
int N,
int F,
int H,
int W,
int K,
const bool perspective_correct,
const int N,
const int F,
const int H,
const int W,
const int K,
const float* grad_zbuf, // (N, H, W, K)
const float* grad_bary, // (N, H, W, K, 3)
const float* grad_dists, // (N, H, W, K)
@@ -351,17 +353,20 @@ __global__ void RasterizeMeshesBackwardCudaKernel(
// Convert linear index to 3D index
const int n = t_i / (H * W); // batch index.
const int pix_idx = t_i % (H * W);
const int yi = pix_idx / H;
const int xi = pix_idx % W;

// Determine ordering based on axis convention.
const int yi = H - 1 - pix_idx / W;
const int xi = W - 1 - pix_idx % W;

const float xf = PixToNdc(xi, W);
const float yf = PixToNdc(yi, H);
const float2 pxy = make_float2(xf, yf);

// Loop over all the faces for this pixel.
for (int k = 0; k < K; k++) {
// Index into (N, H, W, K, :) grad tensors
const int i =
n * H * W * K + yi * H * K + xi * K + k; // pixel index + face index
// pixel index + top k index
int i = n * H * W * K + pix_idx * K + k;

const int f = pix_to_face[i];
if (f < 0) {
@@ -451,7 +456,7 @@ torch::Tensor RasterizeMeshesBackwardCuda(
const torch::Tensor& grad_zbuf, // (N, H, W, K)
const torch::Tensor& grad_bary, // (N, H, W, K, 3)
const torch::Tensor& grad_dists, // (N, H, W, K)
bool perspective_correct) {
const bool perspective_correct) {
const int F = face_verts.size(0);
const int N = pix_to_face.size(0);
const int H = pix_to_face.size(1);
@@ -509,6 +514,7 @@ __global__ void RasterizeMeshesCoarseCudaKernel(
// Have each block handle a chunk of faces
const int chunks_per_batch = 1 + (F - 1) / chunk_size;
const int num_chunks = N * chunks_per_batch;

for (int chunk = blockIdx.x; chunk < num_chunks; chunk += gridDim.x) {
const int batch_idx = chunk / chunks_per_batch; // batch index
const int chunk_idx = chunk % chunks_per_batch;
@@ -551,17 +557,21 @@ __global__ void RasterizeMeshesCoarseCudaKernel(
// Y coordinate of the top and bottom of the bin.
// PixToNdc gives the location of the center of each pixel, so we
// need to add/subtract a half pixel to get the true extent of the bin.
const float bin_y_min = PixToNdc(by * bin_size, H) - half_pix;
const float bin_y_max = PixToNdc((by + 1) * bin_size - 1, H) + half_pix;
// Reverse ordering of Y axis so that +Y is upwards in the image.
const int yidx = num_bins - by;
float bin_y_max = PixToNdc(yidx * bin_size - 1, H) + half_pix;
float bin_y_min = PixToNdc((yidx - 1) * bin_size, H) - half_pix;

const bool y_overlap = (ymin <= bin_y_max) && (bin_y_min < ymax);

for (int bx = 0; bx < num_bins; ++bx) {
// X coordinate of the left and right of the bin.
const float bin_x_min = PixToNdc(bx * bin_size, W) - half_pix;
const float bin_x_max =
PixToNdc((bx + 1) * bin_size - 1, W) + half_pix;
const bool x_overlap = (xmin <= bin_x_max) && (bin_x_min < xmax);
// Reverse ordering of x axis so that +X is left.
const int xidx = num_bins - bx;
float bin_x_max = PixToNdc(xidx * bin_size - 1, W) + half_pix;
float bin_x_min = PixToNdc((xidx - 1) * bin_size, W) - half_pix;

const bool x_overlap = (xmin <= bin_x_max) && (bin_x_min < xmax);
if (y_overlap && x_overlap) {
binmask.set(by, bx, f);
}
@@ -654,7 +664,6 @@ torch::Tensor RasterizeMeshesCoarseCuda(
// ****************************************************************************
// * FINE RASTERIZATION *
// ****************************************************************************

__global__ void RasterizeMeshesFineCudaKernel(
const float* face_verts, // (F, 3, 3)
const int32_t* bin_faces, // (N, B, B, T)
@@ -695,8 +704,14 @@ __global__ void RasterizeMeshesFineCudaKernel(

if (yi >= H || xi >= W)
continue;
const float xf = PixToNdc(xi, W);
const float yf = PixToNdc(yi, H);

// Reverse ordering of the X and Y axis so that
// in the image +Y is pointing up and +X is pointing left.
const int yidx = H - 1 - yi;
const int xidx = W - 1 - xi;

const float xf = PixToNdc(xidx, W);
const float yf = PixToNdc(yidx, H);
const float2 pxy = make_float2(xf, yf);

// This part looks like the naive rasterization kernel, except we use
@@ -751,7 +766,7 @@ RasterizeMeshesFineCuda(
const float blur_radius,
const int bin_size,
const int faces_per_pixel,
bool perspective_correct) {
const bool perspective_correct) {
if (face_verts.ndimension() != 3 || face_verts.size(1) != 3 ||
face_verts.size(2) != 3) {
AT_ERROR("face_verts must have dimensions (num_faces, 3, 3)");
Expand Down
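
The reversed bin ordering in the coarse kernel above can be sanity-checked outside CUDA. The sketch below reproduces the bin-extent arithmetic in plain Python. Two things are assumptions of this example rather than facts shown in the diff: that `PixToNdc` maps the centre of pixel `i` to `-1 + (2i + 1) / size` (NDC spanning [-1, 1]), and that `half_pix` equals `1 / size`. The names `pix_to_ndc` and `bin_extent` are made up.

```python
def pix_to_ndc(i: int, size: int) -> float:
    """Centre of pixel i in NDC, assuming NDC spans [-1, 1] over `size` pixels."""
    return -1.0 + (2.0 * i + 1.0) / size


def bin_extent(b: int, num_bins: int, bin_size: int, size: int):
    """(min, max) NDC extent of bin b with the axis reversed, as in the coarse kernel."""
    half_pix = 1.0 / size  # assumed: half a pixel in NDC units
    idx = num_bins - b     # reversed ordering: bin 0 covers the largest coordinates
    hi = pix_to_ndc(idx * bin_size - 1, size) + half_pix
    lo = pix_to_ndc((idx - 1) * bin_size, size) - half_pix
    return lo, hi


if __name__ == "__main__":
    size, bin_size = 8, 4
    num_bins = size // bin_size
    for b in range(num_bins):
        print(b, bin_extent(b, num_bins, bin_size, size))
    # 0 (0.0, 1.0)
    # 1 (-1.0, 0.0)
```

Under these assumptions the bins still tile the full NDC range without gaps; only the order in which they are visited changes, with bin 0 now covering the positive end of the axis.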
87 changes: 44 additions & 43 deletions pytorch3d/csrc/rasterize_meshes/rasterize_meshes.h
@@ -14,21 +14,21 @@ RasterizeMeshesNaiveCpu(
const torch::Tensor& face_verts,
const torch::Tensor& mesh_to_face_first_idx,
const torch::Tensor& num_faces_per_mesh,
int image_size,
float blur_radius,
int faces_per_pixel,
bool perspective_correct);
const int image_size,
const float blur_radius,
const int faces_per_pixel,
const bool perspective_correct);

#ifdef WITH_CUDA
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
RasterizeMeshesNaiveCuda(
const at::Tensor& face_verts,
const at::Tensor& mesh_to_face_first_idx,
const at::Tensor& num_faces_per_mesh,
int image_size,
float blur_radius,
int num_closest,
bool perspective_correct);
const int image_size,
const float blur_radius,
const int num_closest,
const bool perspective_correct);
#endif
// Forward pass for rasterizing a batch of meshes.
//
@@ -77,10 +77,10 @@ RasterizeMeshesNaive(
const torch::Tensor& face_verts,
const torch::Tensor& mesh_to_face_first_idx,
const torch::Tensor& num_faces_per_mesh,
int image_size,
float blur_radius,
int faces_per_pixel,
bool perspective_correct) {
const int image_size,
const float blur_radius,
const int faces_per_pixel,
const bool perspective_correct) {
// TODO: Better type checking.
if (face_verts.type().is_cuda()) {
#ifdef WITH_CUDA
@@ -117,7 +117,7 @@ torch::Tensor RasterizeMeshesBackwardCpu(
const torch::Tensor& grad_bary,
const torch::Tensor& grad_zbuf,
const torch::Tensor& grad_dists,
bool perspective_correct);
const bool perspective_correct);

#ifdef WITH_CUDA
torch::Tensor RasterizeMeshesBackwardCuda(
@@ -126,7 +126,7 @@ torch::Tensor RasterizeMeshesBackwardCuda(
const torch::Tensor& grad_bary,
const torch::Tensor& grad_zbuf,
const torch::Tensor& grad_dists,
bool perspective_correct);
const bool perspective_correct);
#endif

// Args:
@@ -159,7 +159,7 @@ torch::Tensor RasterizeMeshesBackward(
const torch::Tensor& grad_zbuf,
const torch::Tensor& grad_bary,
const torch::Tensor& grad_dists,
bool perspective_correct) {
const bool perspective_correct) {
if (face_verts.type().is_cuda()) {
#ifdef WITH_CUDA
return RasterizeMeshesBackwardCuda(
@@ -191,20 +191,20 @@ torch::Tensor RasterizeMeshesCoarseCpu(
const torch::Tensor& face_verts,
const at::Tensor& mesh_to_face_first_idx,
const at::Tensor& num_faces_per_mesh,
int image_size,
float blur_radius,
int bin_size,
int max_faces_per_bin);
const int image_size,
const float blur_radius,
const int bin_size,
const int max_faces_per_bin);

#ifdef WITH_CUDA
torch::Tensor RasterizeMeshesCoarseCuda(
const torch::Tensor& face_verts,
const torch::Tensor& mesh_to_face_first_idx,
const torch::Tensor& num_faces_per_mesh,
int image_size,
float blur_radius,
int bin_size,
int max_faces_per_bin);
const int image_size,
const float blur_radius,
const int bin_size,
const int max_faces_per_bin);
#endif
// Args:
// face_verts: Tensor of shape (F, 3, 3) giving (packed) vertex positions for
@@ -232,10 +232,10 @@ torch::Tensor RasterizeMeshesCoarse(
const torch::Tensor& face_verts,
const torch::Tensor& mesh_to_face_first_idx,
const torch::Tensor& num_faces_per_mesh,
int image_size,
float blur_radius,
int bin_size,
int max_faces_per_bin) {
const int image_size,
const float blur_radius,
const int bin_size,
const int max_faces_per_bin) {
if (face_verts.type().is_cuda()) {
#ifdef WITH_CUDA
return RasterizeMeshesCoarseCuda(
@@ -270,11 +270,11 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
RasterizeMeshesFineCuda(
const torch::Tensor& face_verts,
const torch::Tensor& bin_faces,
int image_size,
float blur_radius,
int bin_size,
int faces_per_pixel,
bool perspective_correct);
const int image_size,
const float blur_radius,
const int bin_size,
const int faces_per_pixel,
const bool perspective_correct);
#endif
// Args:
// face_verts: Tensor of shape (F, 3, 3) giving (packed) vertex positions for
@@ -317,11 +317,11 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
RasterizeMeshesFine(
const torch::Tensor& face_verts,
const torch::Tensor& bin_faces,
int image_size,
float blur_radius,
int bin_size,
int faces_per_pixel,
bool perspective_correct) {
const int image_size,
const float blur_radius,
const int bin_size,
const int faces_per_pixel,
const bool perspective_correct) {
if (face_verts.type().is_cuda()) {
#ifdef WITH_CUDA
return RasterizeMeshesFineCuda(
@@ -373,6 +373,7 @@ RasterizeMeshesFine(
// this function instead returns screen-space
// barycentric coordinates for each pixel.
//
//
// Returns:
// A 4 element tuple of:
// pix_to_face: int64 tensor of shape (N, H, W, K) giving the face index of
@@ -394,12 +395,12 @@ RasterizeMeshes(
const torch::Tensor& face_verts,
const torch::Tensor& mesh_to_face_first_idx,
const torch::Tensor& num_faces_per_mesh,
int image_size,
float blur_radius,
int faces_per_pixel,
int bin_size,
int max_faces_per_bin,
bool perspective_correct) {
const int image_size,
const float blur_radius,
const int faces_per_pixel,
const int bin_size,
const int max_faces_per_bin,
const bool perspective_correct) {
if (bin_size > 0 && max_faces_per_bin > 0) {
// Use coarse-to-fine rasterization
auto bin_faces = RasterizeMeshesCoarse(