sample_pdf CUDA and C++ implementations.
Summary: Implement the sample_pdf function from the NeRF project as compiled operators. The binary search (in searchsorted) is replaced in the CUDA kernel with a low-tech linear search, but this is not a problem for the envisaged numbers of bins.

Reviewed By: gkioxari

Differential Revision: D26312535

fbshipit-source-id: df1c3119cd63d944380ed1b2657b6ad81d743e49
bottler authored and facebook-github-bot committed Aug 17, 2021
1 parent 7d7d00f commit 1ea2b72
Showing 7 changed files with 488 additions and 3 deletions.
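For orientation: sample_pdf performs inverse transform sampling from piecewise-constant distributions. Below is a minimal scalar sketch of that logic for a single distribution and a single quantile. It is our illustration, not part of the commit (the name sample_one is an assumption); the committed kernels batch and parallelize these same steps.

#include <cstddef>
#include <vector>

float sample_one(
    const std::vector<float>& bin_edges, // n_bins + 1 edges
    const std::vector<float>& weights,   // n_bins non-negative weights
    float u,                             // quantile in [0, 1)
    float eps) {
  float total = eps;
  for (float w : weights) {
    total += w;
  }
  float target = u * total; // rescale the quantile to the total mass
  std::size_t i = 0;
  // Linear search: walk past whole bins until the target falls inside one.
  while (i + 1 < weights.size() && target > weights[i]) {
    target -= weights[i];
    ++i;
  }
  const float w = weights[i];
  if (target > w) {
    return bin_edges[i + 1];
  }
  if (w > eps) {
    // Linear interpolation within the bin.
    return bin_edges[i] + (target / w) * (bin_edges[i + 1] - bin_edges[i]);
  }
  return bin_edges[i];
}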
4 changes: 4 additions & 0 deletions pytorch3d/csrc/ext.cpp
@@ -26,6 +26,7 @@
#include "point_mesh/point_mesh_cuda.h"
#include "rasterize_meshes/rasterize_meshes.h"
#include "rasterize_points/rasterize_points.h"
#include "sample_pdf/sample_pdf.h"

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("face_areas_normals_forward", &FaceAreasNormalsForward);
@@ -83,6 +84,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("point_face_array_dist_forward", &PointFaceArrayDistanceForward);
m.def("point_face_array_dist_backward", &PointFaceArrayDistanceBackward);

// Sample PDF
m.def("sample_pdf", &SamplePdf);

// Pulsar.
#ifdef PULSAR_LOGGING_ENABLED
c10::ShowLogInfoToStderr();
153 changes: 153 additions & 0 deletions pytorch3d/csrc/sample_pdf/sample_pdf.cu
@@ -0,0 +1,153 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

// There is no intermediate memory, so there is no reason not to use
// blocksize=32. The launch below caps the grid at max_blocks=1024, which
// is a reasonable number of blocks.

// DESIGN
// We exploit the fact that n_samples is not tiny.
// A chunk of work is T*blocksize many samples from
// a single batch element.
// For each batch element there will be
// chunks_per_batch = 1 + (n_samples-1)/(T*blocksize) of them.
// The number of potential chunks to do is
// n_chunks = chunks_per_batch * n_batches.
// These chunks are divided among the gridSize-many blocks.
// In block b, we work on chunks b, b+gridSize, b+2*gridSize etc.
// In chunk i, we work on batch_element i/chunks_per_batch
// on samples starting from (i%chunks_per_batch) * (T*blocksize)
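// Example (illustrative): with blockDim.x = 32 and T = 2, a chunk covers
// 64 samples, so n_samples = 100 gives chunks_per_batch = 1 + 99/64 = 2;
// chunk i = 5 then works on batch element 5/2 = 2, on samples starting
// from (5%2) * 64 = 64.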

// BEGIN HYPOTHETICAL
// Another option (not implemented) if batch_size was always large
// would be as follows.

// A chunk of work is S samples from each of blocksize-many
// batch elements.
// For each batch element there will be
// chunks_per_batch = (1+(n_samples-1)/S) of them.
// The number of potential chunks to do is
// n_chunks = chunks_per_batch * (1+(n_batches-1)/blocksize)
// These chunks are divided among the gridSize-many blocks.
// In block b, we work on chunks b, b+gridSize, b+2*gridSize etc.
// In chunk i, we work on samples starting from S*(i%chunks_per_batch)
// on batch elements starting from blocksize*(i/chunks_per_batch).
// END HYPOTHETICAL
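// (Under that hypothetical scheme, with blocksize = 32 and S = 64,
// n_samples = 100 would give chunks_per_batch = 1 + 99/64 = 2, and
// n_batches = 1000 would give 1 + 999/32 = 32 groups of batch elements,
// for n_chunks = 64 in total.)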

__global__ void SamplePdfCudaKernel(
const float* __restrict__ bins,
const float* __restrict__ weights,
float* __restrict__ outputs,
float eps,
const int T,
const int64_t batch_size,
const int64_t n_bins,
const int64_t n_samples) {
const int64_t chunks_per_batch = 1 + (n_samples - 1) / (T * blockDim.x);
const int64_t n_chunks = chunks_per_batch * batch_size;

for (int64_t i_chunk = blockIdx.x; i_chunk < n_chunks; i_chunk += gridDim.x) {
// Loop over the chunks.
int64_t i_batch_element = i_chunk / chunks_per_batch;
int64_t sample_start = (i_chunk % chunks_per_batch) * (T * blockDim.x);
const float* const weight_startp = weights + n_bins * i_batch_element;
const float* const bin_startp = bins + (1 + n_bins) * i_batch_element;

// Each chunk looks at a single batch element, so we do the preprocessing
// which depends on the batch element, namely finding the total weight.
// Identical work is being done in sync here by every thread of the block.
float total_weight = eps;
for (int64_t i_bin = 0; i_bin < n_bins; ++i_bin) {
total_weight += weight_startp[i_bin];
}

float* const output_startp =
outputs + n_samples * i_batch_element + sample_start;

for (int t = 0; t < T; ++t) {
// Loop over T, which is the number of samples each thread makes within
// the chunk.
const int64_t i_sample_within_chunk = threadIdx.x + t * blockDim.x;
if (sample_start + i_sample_within_chunk >= n_samples) {
// Some threads skip this iteration because the sample they
// would make is out of range.
continue;
}
// output_startp[i_sample_within_chunk] contains the quantile we (i.e.
// this thread) are calculating.
float uniform = total_weight * output_startp[i_sample_within_chunk];
int64_t i_bin = 0;
// We find the bin containing the quantile by walking along the weights.
// This loop is necessarily thread-divergent, i.e. the whole warp will wait
// until every thread has found the bin for its quantile.
// It may be best to write it differently.
while (i_bin + 1 < n_bins && uniform > weight_startp[i_bin]) {
uniform -= weight_startp[i_bin];
++i_bin;
}

// Now we know which bin to look in, we use linear interpolation
// to find the location of the quantile within the bin, and
// write the answer back.
float bin_start = bin_startp[i_bin];
float bin_end = bin_startp[i_bin + 1];
float bin_weight = weight_startp[i_bin];
float output_value = bin_start;
if (uniform > bin_weight) {
output_value = bin_end;
} else if (bin_weight > eps) {
output_value += (uniform / bin_weight) * (bin_end - bin_start);
}
output_startp[i_sample_within_chunk] = output_value;
}
}
}

void SamplePdfCuda(
const at::Tensor& bins,
const at::Tensor& weights,
const at::Tensor& outputs,
float eps) {
// Check inputs are on the same device
at::TensorArg bins_t{bins, "bins", 1}, weights_t{weights, "weights", 2},
outputs_t{outputs, "outputs", 3};
at::CheckedFrom c = "SamplePdfCuda";
at::checkAllSameGPU(c, {bins_t, weights_t, outputs_t});
at::checkAllSameType(c, {bins_t, weights_t, outputs_t});

// Set the device for the kernel launch based on the device of the input
at::cuda::CUDAGuard device_guard(bins.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();

const int64_t batch_size = bins.size(0);
const int64_t n_bins = weights.size(1);
const int64_t n_samples = outputs.size(1);

const int64_t threads = 32;
const int64_t T = n_samples <= threads ? 1 : 2;
const int64_t chunks_per_batch = 1 + (n_samples - 1) / (T * threads);
const int64_t n_chunks = chunks_per_batch * batch_size;

const int64_t max_blocks = 1024;
const int64_t blocks = n_chunks < max_blocks ? n_chunks : max_blocks;
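// Example (illustrative): batch_size = 8 with n_samples = 100 gives T = 2,
// chunks_per_batch = 2 and n_chunks = 16, so 16 blocks of 32 threads are
// launched; only very large workloads hit the 1024-block cap.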

SamplePdfCudaKernel<<<blocks, threads, 0, stream>>>(
bins.contiguous().data_ptr<float>(),
weights.contiguous().data_ptr<float>(),
outputs.data_ptr<float>(), // Checked contiguous in header file.
eps,
T,
batch_size,
n_bins,
n_samples);

AT_CUDA_CHECK(cudaGetLastError());
}
74 changes: 74 additions & 0 deletions pytorch3d/csrc/sample_pdf/sample_pdf.h
@@ -0,0 +1,74 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#pragma once
#include <torch/extension.h>
#include <cstdio>
#include <tuple>
#include "utils/pytorch3d_cutils.h"

// ****************************************************************************
// * SamplePdf *
// ****************************************************************************

// Samples probability density functions defined by bin edges `bins` and
// the non-negative per-bin probabilities `weights`.

// Args:
// bins: FloatTensor of shape `(batch_size, n_bins+1)` denoting the edges
// of the sampling bins.

// weights: FloatTensor of shape `(batch_size, n_bins)` containing
// non-negative numbers representing the probability of sampling the
// corresponding bin.

// outputs: FloatTensor of shape `(batch_size, n_samples)`. On call it
// contains the uniform quantiles to draw; it is overwritten in place
// with the `n_samples` samples drawn from each distribution.

// eps: A constant preventing division by zero in case empty bins are
// present.

// Not differentiable

#ifdef WITH_CUDA
void SamplePdfCuda(
const torch::Tensor& bins,
const torch::Tensor& weights,
const torch::Tensor& outputs,
float eps);
#endif

void SamplePdfCpu(
const torch::Tensor& bins,
const torch::Tensor& weights,
const torch::Tensor& outputs,
float eps);

inline void SamplePdf(
const torch::Tensor& bins,
const torch::Tensor& weights,
const torch::Tensor& outputs,
float eps) {
if (bins.is_cuda()) {
#ifdef WITH_CUDA
CHECK_CUDA(weights);
CHECK_CONTIGUOUS_CUDA(outputs);
SamplePdfCuda(bins, weights, outputs, eps);
return;
#else
AT_ERROR("Not compiled with GPU support.");
#endif
}
CHECK_CONTIGUOUS(outputs);
SamplePdfCpu(bins, weights, outputs, eps);
}
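A hypothetical caller-side sketch (our illustration, not part of the commit) of how the dispatcher declared above might be used; draw_samples_example and the tensor shapes are assumptions consistent with the argument docs.

#include <torch/torch.h>
#include "sample_pdf/sample_pdf.h" // declares SamplePdf, as added here

void draw_samples_example() {
  const int64_t batch_size = 8, n_bins = 64, n_samples = 128;
  // Bin edges per batch element: shape (batch_size, n_bins + 1).
  torch::Tensor bins =
      torch::linspace(0.0, 1.0, n_bins + 1).repeat({batch_size, 1});
  // Non-negative per-bin weights: shape (batch_size, n_bins).
  torch::Tensor weights = torch::rand({batch_size, n_bins});
  // Quantiles in, samples out: shape (batch_size, n_samples).
  torch::Tensor outputs = torch::rand({batch_size, n_samples});
  SamplePdf(bins, weights, outputs, 1e-5f);
  // outputs now holds n_samples draws from each batch distribution.
}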
141 changes: 141 additions & 0 deletions pytorch3d/csrc/sample_pdf/sample_pdf_cpu.cpp
@@ -0,0 +1,141 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <torch/extension.h>
#include <algorithm>
#include <thread>
#include <vector>

// If the number of bins is the typical 64, it is
// quicker to use binary search than a linear scan,
// and the advantage grows with more bins.
// There is no equivalent CUDA implementation yet.
#define USE_BINARY_SEARCH

namespace {
// This worker function does the job of SamplePdf but only on
// batch elements in [start_batch, end_batch).
void SamplePdfCpu_worker(
const torch::Tensor& bins,
const torch::Tensor& weights,
const torch::Tensor& outputs,
float eps,
int64_t start_batch,
int64_t end_batch) {
const int64_t n_bins = weights.size(1);
const int64_t n_samples = outputs.size(1);

auto bins_a = bins.accessor<float, 2>();
auto weights_a = weights.accessor<float, 2>();
float* __restrict__ output_p =
outputs.data_ptr<float>() + start_batch * n_samples;

#ifdef USE_BINARY_SEARCH
std::vector<float> partial_sums(n_bins);
#endif

for (int64_t i_batch_elt = start_batch; i_batch_elt < end_batch;
++i_batch_elt) {
auto bin_a = bins_a[i_batch_elt];
auto weight_a = weights_a[i_batch_elt];

// Here we do the work which has to be done once per batch element.
// i.e. (1) finding the total weight. (2) If using binary search,
// precompute the partial sums of the weights.

float total_weight = 0;
for (int64_t i_bin = 0; i_bin < n_bins; ++i_bin) {
total_weight += weight_a[i_bin];
#ifdef USE_BINARY_SEARCH
partial_sums[i_bin] = total_weight;
#endif
}
total_weight += eps;

for (int64_t i_sample = 0; i_sample < n_samples; ++i_sample) {
// Here we are taking a single random quantile (which is stored
// in *output_p) and using it to make a single sample, which we
// write back to the same location. First we find which bin
// the quantile lives in, either by binary search in the
// precomputed partial sums, or by scanning through the weights.

float uniform = total_weight * *output_p;
#ifdef USE_BINARY_SEARCH
int64_t i_bin = std::lower_bound(
partial_sums.begin(), partial_sums.end() - 1, uniform) -
partial_sums.begin();
if (i_bin > 0) {
uniform -= partial_sums[i_bin - 1];
}
#else
int64_t i_bin = 0;
while (i_bin + 1 < n_bins && uniform > weight_a[i_bin]) {
uniform -= weight_a[i_bin];
++i_bin;
}
#endif

// Now i_bin identifies the bin the quantile lives in, we use
// straight line interpolation to find the position of the
// quantile within the bin, and write it to *output_p.

float bin_start = bin_a[i_bin];
float bin_end = bin_a[i_bin + 1];
float bin_weight = weight_a[i_bin];
float output_value = bin_start;
if (uniform > bin_weight) {
output_value = bin_end;
} else if (bin_weight > eps) {
output_value += (uniform / bin_weight) * (bin_end - bin_start);
}
*output_p = output_value;
++output_p;
}
}
}

} // anonymous namespace

void SamplePdfCpu(
const torch::Tensor& bins,
const torch::Tensor& weights,
const torch::Tensor& outputs,
float eps) {
const int64_t batch_size = bins.size(0);
const int64_t max_threads = std::min(4, at::get_num_threads());
const int64_t n_threads = std::min(max_threads, batch_size);
if (batch_size == 0) {
return;
}

// SamplePdfCpu_worker does the work of this function. We send separate ranges
// of batch elements to that function in n_threads - 1 separate threads.

std::vector<std::thread> threads;
threads.reserve(n_threads - 1);
const int64_t batch_elements_per_thread = 1 + (batch_size - 1) / n_threads;
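// Example (illustrative): batch_size = 10 with n_threads = 4 gives
// batch_elements_per_thread = 3; the three worker threads take batch
// ranges [0,3), [3,6) and [6,9), and this thread handles [9,10) below.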
int64_t start_batch = 0;
for (int i_thread = 0; i_thread < n_threads - 1; ++i_thread) {
threads.emplace_back(
SamplePdfCpu_worker,
bins,
weights,
outputs,
eps,
start_batch,
start_batch + batch_elements_per_thread);
start_batch += batch_elements_per_thread;
}

// The remaining batch elements are calculated in this thread. If n_threads
// is 1 then all the work happens in this line.
SamplePdfCpu_worker(bins, weights, outputs, eps, start_batch, batch_size);
for (auto&& thread : threads) {
thread.join();
}
}
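For reference, a self-contained sketch (our illustration, not from the commit) of the partial-sums binary search used by the USE_BINARY_SEARCH path above.

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  // Weights {1, 2, 3, 4} give partial sums {1, 3, 6, 10}.
  std::vector<float> partial_sums{1.f, 3.f, 6.f, 10.f};
  float uniform = 4.5f; // rescaled quantile in [0, total_weight)
  // Search all but the last entry so quantiles >= total land in the last bin.
  int64_t i_bin = std::lower_bound(
                      partial_sums.begin(), partial_sums.end() - 1, uniform) -
      partial_sums.begin();
  assert(i_bin == 2); // 4.5 falls in bin 2, whose cumulative range is (3, 6]
  if (i_bin > 0) {
    uniform -= partial_sums[i_bin - 1]; // offset of the quantile within bin 2
  }
  assert(uniform == 1.5f);
  return 0;
}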