-
Notifications
You must be signed in to change notification settings - Fork 1.7k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Feature] Support ball_query with cambricon MLU backend and mlu-ops library. (#2520) * [Feature] Support ball_query with cambricon MLU backend and mlu-ops library. * [Fix] Update operator data layout setting. * [Fix] Add C++ compile option to avoid symbol conflict. * [Fix] Fix lint errors. * [Fix] Update ops.md with info of ball_query support by the MLU backend. * [Fix] Fix typo. * [Fix] Remove print. * [Fix] Get mlu-ops from the MMCV_MLU_OPS_PATH env. * [Fix] Update MMCV_MLU_OPS_PATH check logic. * [Fix] Update the error info shown when downloading mlu-ops fails. * [Fix] Check mlu-ops version-matching info in mmcv. * [Fix] Revise wrong filename. * [Fix] Remove f.close and re.
- Loading branch information
1 parent
84f60c1
commit dfb0380
Showing
7 changed files
with
336 additions
and
44 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
/************************************************************************* | ||
* Copyright (C) 2022 Cambricon. | ||
* | ||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS | ||
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF | ||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. | ||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY | ||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, | ||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE | ||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. | ||
*************************************************************************/ | ||
#include "mlu_common_helper.h" | ||
|
||
void ball_query_forward_mlu(int b, int n, int m, float min_radius, | ||
float max_radius, int nsample, const Tensor new_xyz, | ||
const Tensor xyz, Tensor idx) { | ||
MluOpTensorDescriptor new_xyz_desc, xyz_desc, idx_desc; | ||
new_xyz_desc.set(new_xyz); | ||
xyz_desc.set(xyz); | ||
idx_desc.set(idx); | ||
|
||
auto new_xyz_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous( | ||
new_xyz, new_xyz.suggest_memory_format()); | ||
auto xyz_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous( | ||
xyz, new_xyz.suggest_memory_format()); | ||
auto idx_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous( | ||
idx, new_xyz.suggest_memory_format()); | ||
|
||
auto new_xyz_impl = torch_mlu::getMluTensorImpl(new_xyz_contiguous); | ||
auto xyz_impl = torch_mlu::getMluTensorImpl(xyz_contiguous); | ||
auto idx_impl = torch_mlu::getMluTensorImpl(idx_contiguous); | ||
auto new_xyz_ptr = new_xyz_impl->cnnlMalloc(); | ||
auto xyz_ptr = xyz_impl->cnnlMalloc(); | ||
auto idx_ptr = idx_impl->cnnlMalloc(); | ||
|
||
auto handle = mluOpGetCurrentHandle(); | ||
mluOpBallQuery(handle, new_xyz_desc.desc(), new_xyz_ptr, xyz_desc.desc(), | ||
xyz_ptr, min_radius, max_radius, nsample, idx_desc.desc(), | ||
idx_ptr); | ||
} | ||
|
||
// Dispatch entry point declared by MMCV's device registry; the MLU kernel
// above is bound to it below so MLU tensors route to ball_query_forward_mlu.
void ball_query_forward_impl(int b, int n, int m, float min_radius,
                             float max_radius, int nsample,
                             const Tensor new_xyz, const Tensor xyz,
                             Tensor idx);

// Register the MLU backend implementation for ball_query_forward_impl.
REGISTER_DEVICE_IMPL(ball_query_forward_impl, MLU, ball_query_forward_mlu);
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,103 @@ | ||
/************************************************************************* | ||
* Copyright (C) 2022 Cambricon. | ||
* | ||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS | ||
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF | ||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. | ||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY | ||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, | ||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE | ||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. | ||
*************************************************************************/ | ||
#include "mlu_common_helper.h" | ||
|
||
// Descriptors | ||
// Translate an ATen dtype (by its printable type name) into the matching
// mlu-ops data type enum. Returns MLUOP_DTYPE_INVALID for unsupported types.
mluOpDataType_t getMluOpDataType(const caffe2::TypeMeta& data_type) {
  // static const: the name -> enum table never changes, so build it once
  // instead of reconstructing a std::map on every call.
  static const std::map<std::string, mluOpDataType_t> mapping_type = {
      {std::string("c10::Half"), MLUOP_DTYPE_HALF},
      {std::string("float"), MLUOP_DTYPE_FLOAT},
      {std::string("double"), MLUOP_DTYPE_DOUBLE},
      {std::string("int8"), MLUOP_DTYPE_INT8},
      {std::string("signed char"), MLUOP_DTYPE_INT8},
      {std::string("short int"), MLUOP_DTYPE_INT16},
      {std::string("short"), MLUOP_DTYPE_INT16},
      {std::string("int"), MLUOP_DTYPE_INT32},
      {std::string("long int"), MLUOP_DTYPE_INT64},
      {std::string("long"), MLUOP_DTYPE_INT64},
      {std::string("unsigned char"), MLUOP_DTYPE_UINT8},
      {std::string("bool"), MLUOP_DTYPE_BOOL},
      {std::string("c10::complex<c10::Half>"), MLUOP_DTYPE_COMPLEX_HALF},
      {std::string("c10::complex<float>"), MLUOP_DTYPE_COMPLEX_FLOAT}};

  // Single lookup (the original called map::find twice for the same key).
  const auto it = mapping_type.find(std::string(data_type.name()));
  if (it != mapping_type.end()) {
    return it->second;
  }
  return MLUOP_DTYPE_INVALID;
}
|
||
// layout
// Pick the mlu-ops tensor layout that matches ATen's suggested memory format:
// 4-D tensors map to NHWC/NCHW, 5-D tensors to NDHWC/NCDHW, and everything
// else is treated as a plain array.
mluOpTensorLayout_t getMluOpSuggestLayout(const at::Tensor& input) {
  const auto memory_format = input.suggest_memory_format();
  const auto ndim = input.dim();
  if (ndim == 4) {
    return (memory_format == at::MemoryFormat::ChannelsLast)
               ? MLUOP_LAYOUT_NHWC
               : MLUOP_LAYOUT_NCHW;
  }
  if (ndim == 5) {
    return (memory_format == at::MemoryFormat::ChannelsLast3d)
               ? MLUOP_LAYOUT_NDHWC
               : MLUOP_LAYOUT_NCDHW;
  }
  return MLUOP_LAYOUT_ARRAY;
}
|
||
// Populate this descriptor with the tensor's dtype, suggested layout and
// shape. A 0-dim scalar tensor is described as a 1-element 1-D tensor.
void MluOpTensorDescriptor::set(Tensor t) {
  mluOpDataType_t data_type = getMluOpDataType(t.dtype());
  mluOpTensorLayout_t layout = getMluOpSuggestLayout(t);
  int t_dim = t.dim();
  std::vector<int> dim_array;
  if (t_dim == 0) {
    // ScalarTensor (0-dim, 1-item) viewed as size = 1 by default.
    dim_array.push_back(1);
  } else {
    dim_array.reserve(t_dim);
    for (int i = 0; i < t_dim; i++) {
      // t.size(i) reads the size directly; the original t.sizes().vec()[i]
      // materialized a fresh std::vector on every iteration.
      dim_array.push_back(static_cast<int>(t.size(i)));
    }
  }
  set_desc(t, layout, data_type, dim_array);
}
|
||
// Hand the layout/dtype/shape triple to mluOpSetTensorDescriptor; `t` is
// accepted for signature symmetry with set() but not consulted here.
void MluOpTensorDescriptor::set_desc(const at::Tensor& t,
                                     mluOpTensorLayout_t layout,
                                     mluOpDataType_t dtype,
                                     std::vector<int>& dims) {
  const int dim_count = static_cast<int>(dims.size());
  mluOpSetTensorDescriptor(desc_, layout, dtype, dim_count, dims.data());
}
|
||
// Handles
// One mlu-ops handle per MLU device, created lazily on first use by
// mluOpGetCurrentHandle() below.
std::once_flag mmcv_mluop_init_flag;  // guards one-time sizing of the vector
std::mutex mmcv_mluop_mutex;          // serializes access to the handles
static std::vector<MluOpHandle> mmcv_mluop_handles;
|
||
// Return the mlu-ops handle for `device_index` (the current device when -1),
// after rebinding it to that device's current queue. The handle vector is
// sized exactly once, one handle per visible MLU device.
mluOpHandle_t mluOpGetCurrentHandle(c10::DeviceIndex device_index) {
  std::call_once(mmcv_mluop_init_flag,
                 []() // Init mmcv_mluop_handles 1-device <-> 1-handle
                 {
                   c10::DeviceIndex num_devices = torch_mlu::device_count();
                   mmcv_mluop_handles.resize(num_devices);
                 });

  if (device_index == -1) {
    device_index = torch_mlu::current_device();
  }
  // Lock so concurrent callers do not interleave setQueue with each other;
  // released automatically at function exit (lock_guard).
  std::lock_guard<std::mutex> mmcv_mluop_guard(mmcv_mluop_mutex);
  auto queue = torch_mlu::getCurrentQueue(device_index).queue();
  mmcv_mluop_handles[device_index].setQueue(queue);
  return mmcv_mluop_handles[device_index].handle;
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
/************************************************************************* | ||
* Copyright (C) 2022 Cambricon. | ||
* | ||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS | ||
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF | ||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. | ||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY | ||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, | ||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE | ||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. | ||
*************************************************************************/ | ||
#pragma once | ||
#include <ATen/ATen.h> | ||
#include <c10/core/ScalarType.h> | ||
|
||
#include "aten.h" | ||
#include "mlu_op.h" | ||
#include "pytorch_device_registry.hpp" | ||
|
||
#define MLUOP_MAJOR 0 | ||
#define MLUOP_MINOR 4 | ||
#define MLUOP_PATCHLEVEL 1 | ||
|
||
mluOpDataType_t getMluOpDataType(const caffe2::TypeMeta& data_type); | ||
mluOpTensorLayout_t getMluOpSuggestLayout(const at::Tensor& input); | ||
|
||
class MluOpTensorDescriptor { | ||
public: | ||
MluOpTensorDescriptor() { mluOpCreateTensorDescriptor(&desc_); }; | ||
~MluOpTensorDescriptor() { mluOpDestroyTensorDescriptor(desc_); } | ||
|
||
void set(at::Tensor); | ||
mluOpTensorDescriptor_t desc() { return desc_; } | ||
|
||
private: | ||
mluOpTensorDescriptor_t desc_; | ||
void set_desc(const at::Tensor&, mluOpTensorLayout_t, mluOpDataType_t, | ||
std::vector<int>& dims); | ||
}; | ||
|
||
mluOpHandle_t mluOpGetCurrentHandle(c10::DeviceIndex device_index = -1); | ||
|
||
class MluOpHandle { | ||
public: | ||
MluOpHandle() : handle(nullptr) { mluOpCreate(&handle); } | ||
~MluOpHandle() { | ||
if (handle) { | ||
mluOpDestroy(handle); | ||
handle = nullptr; | ||
} | ||
} | ||
void setQueue(cnrtQueue_t queue) { mluOpSetQueue(handle, queue); } | ||
mluOpHandle_t handle; | ||
}; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.