Skip to content

Commit

Permalink
Merge pull request #619 from mlverse/contrib/sort-vertices
Browse files Browse the repository at this point in the history
Add CUDA code for sort_vertices.
  • Loading branch information
dfalbel authored Jul 27, 2021
2 parents 5e6395c + e9800a2 commit 38e5547
Show file tree
Hide file tree
Showing 19 changed files with 546 additions and 6 deletions.
1 change: 1 addition & 0 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ Collate:
'compat-purrr.R'
'compilation_unit.R'
'conditions.R'
'contrib.R'
'creation-ops.R'
'cuda.R'
'device.R'
Expand Down
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ export(autograd_set_grad_mode)
export(backends_mkl_is_available)
export(backends_mkldnn_is_available)
export(backends_openmp_is_available)
export(contrib_sort_vertices)
export(cuda_current_device)
export(cuda_device_count)
export(cuda_is_available)
Expand Down
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
- Added Linear Algebra functions (#612)
- Fixed a bug when using a `.getbatch` method that didn't return a `torch_tensor`. (#615)
- Fixed warning when using `%/%` caused by a call to deprecated `torch_floor_divide` (#616)
- Added `contrib_sort_vertices` to efficiently sort vertices on CUDA. (#619)

# torch 0.4.0

Expand Down
4 changes: 4 additions & 0 deletions R/RcppExports.R
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,10 @@ cpp_contrib_torch_sparsemax <- function(input, dim) {
.Call('_torch_cpp_contrib_torch_sparsemax', PACKAGE = 'torchpkg', input, dim)
}

cpp_contrib_torch_sort_vertices <- function(vertices, mask, num_valid) {
.Call('_torch_cpp_contrib_torch_sort_vertices', PACKAGE = 'torchpkg', vertices, mask, num_valid)
}

cpp_cuda_is_available <- function() {
.Call('_torch_cpp_cuda_is_available', PACKAGE = 'torchpkg')
}
Expand Down
24 changes: 24 additions & 0 deletions R/contrib.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
#' Contrib sort vertices
#'
#' Based on the implementation from [Rotated_IoU](https://github.com/lilanxiao/Rotated_IoU)
#'
#' @note This function does not make part of the official torch API.
#' @details All tensors should be on a CUDA device so this function can be used.
#'
#' @param vertices A Tensor with the vertices.
#' @param mask A tensors containing the masks.
#' @param num_valid A integer tensors.
#'
#' @examples
#' if (cuda_is_available()) {
#' v <- torch_randn(8, 1024, 24, 2)$cuda()
#' mean <- torch_mean(v, dim=2, keepdim=TRUE)
#' v <- v - mean
#' m <- (torch_rand(8, 1024, 24) > 0.8)$cuda()
#' nv <- torch_sum(m$to(dtype = torch_int()), dim=-1)$to(dtype = torch_int())$cuda()
#' result <- contrib_sort_vertices(v, m, nv)
#' }
#' @export
contrib_sort_vertices <- function(vertices, mask, num_valid) {
cpp_contrib_torch_sort_vertices(vertices, mask, num_valid)
}
11 changes: 10 additions & 1 deletion lantern/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -126,10 +126,19 @@ if(DEFINED ENV{CUDA})
set(LANTERN_SRC
${LANTERN_SRC}
src/AllocatorCuda.cpp
src/Contrib/SortVertices/sort_vert_kernel.cu
src/Contrib/SortVertices/sort_vert.cpp
)

cuda_add_library(lantern SHARED ${LANTERN_SRC})
else()
set(LANTERN_SRC
${LANTERN_SRC}
src/Contrib/SortVertices/sort_vert_cpu.cpp
)
add_library(lantern SHARED ${LANTERN_SRC})
endif()

add_library(lantern SHARED ${LANTERN_SRC})
add_library(lantern::library ALIAS lantern)

target_include_directories(lantern PUBLIC
Expand Down
9 changes: 9 additions & 0 deletions lantern/include/lantern/lantern.h
Original file line number Diff line number Diff line change
Expand Up @@ -1877,6 +1877,14 @@ HOST_API void lantern_vector_Scalar_delete (void* x)

}

LANTERN_API void* (LANTERN_PTR _lantern_contrib_sort_vertices) (void* vertices, void* mask, void* num_valid);
HOST_API void* lantern_contrib_sort_vertices(void* vertices, void* mask, void* num_valid)
{
void* ret = _lantern_contrib_sort_vertices(vertices, mask, num_valid);
LANTERN_HOST_HANDLER;
return ret;
}

/* Autogen Headers -- Start */
LANTERN_API void* (LANTERN_PTR _lantern__cast_byte_tensor_bool)(void* self, void* non_blocking);
HOST_API void* lantern__cast_byte_tensor_bool(void* self, void* non_blocking) { void* ret = _lantern__cast_byte_tensor_bool(self, non_blocking); LANTERN_HOST_HANDLER return ret; }
Expand Down Expand Up @@ -7449,6 +7457,7 @@ LOAD_SYMBOL(_lantern_vector_Scalar_push_back);
LOAD_SYMBOL(_lantern_vector_Scalar_size);
LOAD_SYMBOL(_lantern_vector_Scalar_at);
LOAD_SYMBOL(_lantern_vector_Scalar_delete);
LOAD_SYMBOL(_lantern_contrib_sort_vertices);
/* Autogen Symbols -- Start */
LOAD_SYMBOL(_lantern__cast_byte_tensor_bool)
LOAD_SYMBOL(_lantern__cast_char_tensor_bool)
Expand Down
24 changes: 24 additions & 0 deletions lantern/src/Contrib/SortVertices/README
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
Files in this directory are vendored from https://github.com/lilanxiao/Rotated_IoU
They are licensed under MIT:

MIT License

Copyright (c) 2020 Lanxiao Li

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
63 changes: 63 additions & 0 deletions lantern/src/Contrib/SortVertices/cuda_utils.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
/*
MIT License
Copyright (c) 2020 Lanxiao Li
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
*/

#ifndef _CUDA_UTILS_H
#define _CUDA_UTILS_H

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <cmath>
#include <cuda.h>
#include <cuda_runtime.h>
#include <vector>

#define TOTAL_THREADS 512

inline int opt_n_thread(int work_size){
const int pow_2 = std::log(static_cast<double>(work_size)) / std::log(2.0);
return max(min(1<<pow_2, TOTAL_THREADS), 1);
}

inline dim3 opt_block_config(int x, int y){
const int x_thread = opt_n_thread(x);
const int y_thread = max(min(opt_n_thread(y), TOTAL_THREADS/x_thread), 1);
dim3 block_config(x_thread, y_thread, 1);

return block_config;
}

# define CUDA_CHECK_ERRORS() \
do { \
cudaError_t err = cudaGetLastError(); \
if (cudaSuccess!=err){ \
fprintf(stderr, "CUDA kernel failed : %s\n%s at L:%d in %s\n", \
cudaGetErrorString(err), __PRETTY_FUNCTION__, __LINE__, \
__FILE__); \
exit(-1); \
} \
} while(0) \

#endif
67 changes: 67 additions & 0 deletions lantern/src/Contrib/SortVertices/sort_vert.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
/*
MIT License
Copyright (c) 2020 Lanxiao Li
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
*/

#define LANTERN_BUILD
#include "lantern/lantern.h"
#include "utils.h"
#include "sort_vert.h"
#include <torch/torch.h>
#include "../../utils.hpp"

void sort_vertices_wrapper(int b, int n, int m, const float *vertices, const bool *mask, const int *num_valid, int* idx);

at::Tensor sort_vertices(at::Tensor vertices, at::Tensor mask, at::Tensor num_valid){
CHECK_CONTIGUOUS(vertices);
CHECK_CONTIGUOUS(mask);
CHECK_CONTIGUOUS(num_valid);
CHECK_CUDA(vertices);
CHECK_CUDA(mask);
CHECK_CUDA(num_valid);
CHECK_IS_FLOAT(vertices);
CHECK_IS_BOOL(mask);
CHECK_IS_INT(num_valid);

int b = vertices.size(0);
int n = vertices.size(1);
int m = vertices.size(2);
at::Tensor idx = torch::zeros({b, n, MAX_NUM_VERT_IDX},
at::device(vertices.device()).dtype(at::ScalarType::Int));

sort_vertices_wrapper(b, n, m, vertices.data_ptr<float>(), mask.data_ptr<bool>(),
num_valid.data_ptr<int>(), idx.data_ptr<int>());

return idx;
}

void* _lantern_contrib_sort_vertices (void* vertices, void* mask, void* num_valid)
{
torch::Tensor result = sort_vertices(
reinterpret_cast<LanternObject<torch::Tensor>*>(vertices)->get(),
reinterpret_cast<LanternObject<torch::Tensor>*>(mask)->get(),
reinterpret_cast<LanternObject<torch::Tensor>*>(num_valid)->get()
);
return (void*) new LanternObject<torch::Tensor>(result);
}
30 changes: 30 additions & 0 deletions lantern/src/Contrib/SortVertices/sort_vert.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
/*
MIT License
Copyright (c) 2020 Lanxiao Li
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
*/

#pragma once
#define MAX_NUM_VERT_IDX 9

at::Tensor sort_vertices(at::Tensor vertices, at::Tensor mask, at::Tensor num_valid);
12 changes: 12 additions & 0 deletions lantern/src/Contrib/SortVertices/sort_vert_cpu.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
#define LANTERN_BUILD
#include "lantern/lantern.h"
#include <torch/torch.h>
#include "../../utils.hpp"


void* _lantern_contrib_sort_vertices (void* vertices, void* mask, void* num_valid)
{
LANTERN_FUNCTION_START
throw std::runtime_error("`sort_vertices` is only supported on CUDA runtimes.");
LANTERN_FUNCTION_END
}
Loading

0 comments on commit 38e5547

Please sign in to comment.