From 13ff2a36543bd7067326a80dc9825906a3faa949 Mon Sep 17 00:00:00 2001 From: jiangyuhao Date: Mon, 28 Aug 2023 16:27:22 +0800 Subject: [PATCH] [Refactor] Replace tin_shift op of MLU backend with mlu-ops --- .../csrc/common/mlu/tin_shift_mlu_kernel.mlu | 307 ------------------ .../pytorch/mlu/focal_loss_sigmoid_mlu.cpp | 28 +- mmcv/ops/csrc/pytorch/mlu/mlu_common_helper.h | 4 +- .../csrc/pytorch/mlu/roipoint_pool3d_mlu.cpp | 26 +- mmcv/ops/csrc/pytorch/mlu/tin_shift_mlu.cpp | 155 +++------ 5 files changed, 76 insertions(+), 444 deletions(-) delete mode 100644 mmcv/ops/csrc/common/mlu/tin_shift_mlu_kernel.mlu diff --git a/mmcv/ops/csrc/common/mlu/tin_shift_mlu_kernel.mlu b/mmcv/ops/csrc/common/mlu/tin_shift_mlu_kernel.mlu deleted file mode 100644 index ed64c2b68c..0000000000 --- a/mmcv/ops/csrc/common/mlu/tin_shift_mlu_kernel.mlu +++ /dev/null @@ -1,307 +0,0 @@ -/************************************************************************* - * Copyright (C) 2022 Cambricon. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS - * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF - * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. - * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY - * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, - * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE - * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. - *************************************************************************/ -#include "common_mlu_helper.hpp" - -__nram__ char data_nram[MAX_NRAM_SIZE]; - -template -__mlu_func__ void mluMultiKernelTinShift( - const T *input, const int *shifts, T *output, const int batch_size, - const int time_size, const int channel_size, const int hw_size, - const int group_size, const int group_channel) { - for (int cur_channel_index = taskId; - cur_channel_index < batch_size * channel_size; - cur_channel_index += taskDim) { - int n_index = cur_channel_index / channel_size; - int group_id = cur_channel_index % channel_size / group_channel; - int t_shift = shifts[n_index * group_size + group_id]; - int index = cur_channel_index % channel_size * hw_size + - n_index * time_size * channel_size * hw_size; - __bang_write_value(data_nram, MAX_NRAM_SIZE, (char)0); - __asm__ volatile("sync;"); - if (abs(t_shift) >= time_size) { - __memcpy(output + index, data_nram, hw_size * sizeof(T), NRAM2GDRAM, - channel_size * hw_size * sizeof(T), hw_size * sizeof(T), - time_size - 1); - } else { - if (t_shift > 0) { - __memcpy(data_nram + t_shift * hw_size * sizeof(T), input + index, - hw_size * sizeof(T), GDRAM2NRAM, hw_size * sizeof(T), - channel_size * hw_size * sizeof(T), time_size - 1 - t_shift); - __memcpy(output + index, data_nram, hw_size * sizeof(T), NRAM2GDRAM, - channel_size * hw_size * sizeof(T), hw_size * sizeof(T), - time_size - 1); - } else { - __memcpy(data_nram, input + (index - t_shift * channel_size * hw_size), - hw_size * sizeof(T), GDRAM2NRAM, hw_size * sizeof(T), - channel_size * hw_size * sizeof(T), time_size - 1 + t_shift); - __memcpy(output + index, data_nram, hw_size * sizeof(T), NRAM2GDRAM, - channel_size * hw_size * sizeof(T), hw_size * sizeof(T), - time_size - 1); - } - } - __asm__ volatile("sync;"); - } -} - -template -__mlu_func__ void mluHwSplit(const T *input, const int t_shift, - const int time_size, const int hw_size, - const int channel_size, const int index, - const int cur_sequence_index, - const int max_length_per_core, T *output) { - for 
(int cur_index = index; cur_index < index + hw_size; - cur_index += max_length_per_core) { - int memcpy_size = max_length_per_core; - if (cur_index + max_length_per_core > index + hw_size) { - memcpy_size = index + hw_size - cur_index; - } - if (cur_sequence_index - t_shift < 0 || - cur_sequence_index - t_shift >= time_size) { - __memcpy(output + cur_index, data_nram, memcpy_size * sizeof(T), - NRAM2GDRAM); - } else { - __memcpy(data_nram, input + cur_index - t_shift * channel_size * hw_size, - memcpy_size * sizeof(T), GDRAM2NRAM); - __memcpy(output + cur_index, data_nram, memcpy_size * sizeof(T), - NRAM2GDRAM); - } - __asm__ volatile("sync;"); - } -} - -template -__mlu_func__ void mluMultiKernelTinShiftSplitSequence( - const T *input, const int *shifts, T *output, const int batch_size, - const int time_size, const int channel_size, const int hw_size, - const int group_size, const int group_channel, - const int max_number_hw_per_core, const int max_length_per_core) { - const int tmp_max_number_hw_per_core = - max_number_hw_per_core > 0 ? max_number_hw_per_core : 1; - const int loop_time = time_size / tmp_max_number_hw_per_core + - ((time_size % tmp_max_number_hw_per_core) > 0 ? 1 : 0); - int segmentime_size = tmp_max_number_hw_per_core; - int res_segment = time_size % tmp_max_number_hw_per_core; - - for (int cur_segment_index = taskId; - cur_segment_index < loop_time * batch_size * channel_size; - cur_segment_index += taskDim) { - int n_index = cur_segment_index / loop_time / channel_size; - int group_id = cur_segment_index / loop_time % channel_size / group_channel; - int t_shift = shifts[n_index * group_size + group_id]; - int index = n_index * time_size * channel_size * hw_size + - (cur_segment_index / loop_time % channel_size) * hw_size + - cur_segment_index % loop_time * segmentime_size * hw_size * - channel_size; - char *dst_gdram2nram = data_nram; - const T *src_gdram2nram = input + index; - int count_gdram2nram = -1; - int count_nram2gdram = -1; - int next_sequence_index = - index / hw_size / channel_size % time_size + segmentime_size; - int cur_sequence_index = index / hw_size / channel_size % time_size; - __bang_write_value(data_nram, MAX_NRAM_SIZE, (char)0); - __asm__ volatile("sync;"); - if (max_number_hw_per_core == 0) { - mluHwSplit(input, t_shift, time_size, hw_size, channel_size, index, - cur_sequence_index, max_length_per_core, output); - continue; - } - if (abs(t_shift) >= time_size) { - if ((cur_segment_index + 1) % loop_time == 0 && res_segment != 0) { - __memcpy(output + index, data_nram, hw_size * sizeof(T), NRAM2GDRAM, - channel_size * hw_size * sizeof(T), hw_size * sizeof(T), - res_segment - 1); - } else { - __memcpy(output + index, data_nram, hw_size * sizeof(T), NRAM2GDRAM, - channel_size * hw_size * sizeof(T), hw_size * sizeof(T), - segmentime_size - 1); - } - continue; - } - if (t_shift == 0) { - if ((cur_segment_index + 1) % loop_time == 0 && res_segment != 0) { - dst_gdram2nram = data_nram; - src_gdram2nram = input + index; - count_gdram2nram = res_segment - 1; - count_nram2gdram = res_segment - 1; - } else { - dst_gdram2nram = data_nram; - src_gdram2nram = input + index; - count_gdram2nram = segmentime_size - 1; - count_nram2gdram = segmentime_size - 1; - } - } else if (t_shift > 0) { - int first_index_cur_channel = - n_index * time_size * channel_size * hw_size + - (cur_segment_index / loop_time % channel_size) * hw_size; - if ((cur_segment_index + 1) % loop_time == 0 && res_segment != 0) { - dst_gdram2nram = data_nram; - src_gdram2nram = - input + - (index 
- t_shift * channel_size * hw_size < first_index_cur_channel - ? first_index_cur_channel - : index - t_shift * channel_size * hw_size); - count_gdram2nram = res_segment - 1; - count_nram2gdram = res_segment - 1; - if (cur_sequence_index < t_shift && t_shift < next_sequence_index) { - dst_gdram2nram = - data_nram + t_shift % segmentime_size * hw_size * sizeof(T); - count_gdram2nram = res_segment - (t_shift - cur_sequence_index) - 1; - } - } else { - if (t_shift >= next_sequence_index) { - __memcpy(output + index, data_nram, hw_size * sizeof(T), NRAM2GDRAM, - channel_size * hw_size * sizeof(T), hw_size * sizeof(T), - segmentime_size - 1); - continue; - } else if (cur_sequence_index < t_shift && - t_shift < next_sequence_index) { - dst_gdram2nram = - data_nram + t_shift % segmentime_size * hw_size * sizeof(T); - src_gdram2nram = input + first_index_cur_channel; - count_gdram2nram = segmentime_size - (t_shift % segmentime_size) - 1; - count_nram2gdram = segmentime_size - 1; - } else { - dst_gdram2nram = data_nram; - src_gdram2nram = input + index - t_shift * channel_size * hw_size; - count_gdram2nram = segmentime_size - 1; - count_nram2gdram = segmentime_size - 1; - } - } - } else { - int offset_index = time_size + t_shift; - if (cur_sequence_index >= offset_index) { - if ((cur_segment_index + 1) % loop_time == 0 && res_segment != 0) { - __memcpy(output + index, data_nram, hw_size * sizeof(T), NRAM2GDRAM, - channel_size * hw_size * sizeof(T), hw_size * sizeof(T), - res_segment - 1); - continue; - } else { - __memcpy(output + index, data_nram, hw_size * sizeof(T), NRAM2GDRAM, - channel_size * hw_size * sizeof(T), hw_size * sizeof(T), - segmentime_size - 1); - continue; - } - } else { - dst_gdram2nram = data_nram; - src_gdram2nram = input + index - t_shift * channel_size * hw_size; - if (cur_sequence_index - t_shift + segmentime_size < time_size) { - count_gdram2nram = segmentime_size - 1; - count_nram2gdram = segmentime_size - 1; - } else { - count_gdram2nram = time_size - (cur_sequence_index - t_shift) - 1; - count_nram2gdram = - (segmentime_size - 1) < (time_size - cur_sequence_index - 1) - ? 
(segmentime_size - 1) - : (time_size - cur_sequence_index - 1); - } - } - } - __memcpy(dst_gdram2nram, src_gdram2nram, hw_size * sizeof(T), GDRAM2NRAM, - hw_size * sizeof(T), channel_size * hw_size * sizeof(T), - count_gdram2nram); - __memcpy(output + index, data_nram, hw_size * sizeof(T), NRAM2GDRAM, - channel_size * hw_size * sizeof(T), hw_size * sizeof(T), - count_nram2gdram); - __asm__ volatile("sync;"); - } -} - -__mlu_entry__ void MLUUnion1KernelTinShift( - const void *input, const void *shifts, void *output, const int batch_size, - const int time_size, const int channel_size, const int hw_size, - const int group_size, const int group_channel, - const cnrtDataType_t data_dtype) { - // make sure that memcore is not used - if (coreId == 0x80) { - return; - } - switch (data_dtype) { - case CNRT_FLOAT16: { - mluMultiKernelTinShift((half *)input, (const int *)shifts, (half *)output, - batch_size, time_size, channel_size, hw_size, - group_size, group_channel); - }; break; - case CNRT_FLOAT32: { - mluMultiKernelTinShift((float *)input, (const int *)shifts, - (float *)output, batch_size, time_size, - channel_size, hw_size, group_size, group_channel); - }; break; - default: { return; } - } -} - -__mlu_entry__ void MLUUnion1KernelTinShiftSplitSequence( - const void *input, const void *shifts, void *output, const int batch_size, - const int time_size, const int channel_size, const int hw_size, - const int group_size, const int group_channel, - const int max_number_hw_per_core, const int max_length_per_core, - const cnrtDataType_t data_dtype) { - // make sure that memcore is not used - if (coreId == 0x80) { - return; - } - switch (data_dtype) { - case CNRT_FLOAT16: { - mluMultiKernelTinShiftSplitSequence( - (half *)input, (const int *)shifts, (half *)output, batch_size, - time_size, channel_size, hw_size, group_size, group_channel, - max_number_hw_per_core, max_length_per_core); - }; break; - case CNRT_FLOAT32: { - mluMultiKernelTinShiftSplitSequence( - (float *)input, (const int *)shifts, (float *)output, batch_size, - time_size, channel_size, hw_size, group_size, group_channel, - max_number_hw_per_core, max_length_per_core); - }; break; - default: { return; } - } -} - -void KernelTinShiftForward( - cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue, - const void *input, const void *shifts, void *output, const int batch_size, - const int time_size, const int channel_size, const int hw_size, - const int group_size, const int group_channel, - const cnrtDataType_t data_dtype, const int channel_per_core, - const int max_number_hw_per_core, const int max_length_per_core) { - if (channel_per_core >= 1) { - MLUUnion1KernelTinShift<<>>( - input, shifts, output, batch_size, time_size, channel_size, hw_size, - group_size, group_channel, data_dtype); - } else { - MLUUnion1KernelTinShiftSplitSequence<<>>( - input, shifts, output, batch_size, time_size, channel_size, hw_size, - group_size, group_channel, max_number_hw_per_core, max_length_per_core, - data_dtype); - } -} - -void KernelTinShiftBackward( - cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue, - const void *grad_output, const void *shifts, void *grad_input, - const int batch_size, const int time_size, const int channel_size, - const int hw_size, const int group_size, const int group_channel, - const cnrtDataType_t data_dtype, const int channel_per_core, - const int max_number_hw_per_core, const int max_length_per_core) { - if (channel_per_core >= 1) { - MLUUnion1KernelTinShift<<>>( - grad_output, shifts, grad_input, 
batch_size, time_size, channel_size, - hw_size, group_size, group_channel, data_dtype); - } else { - MLUUnion1KernelTinShiftSplitSequence<<>>( - grad_output, shifts, grad_input, batch_size, time_size, channel_size, - hw_size, group_size, group_channel, max_number_hw_per_core, - max_length_per_core, data_dtype); - } -} diff --git a/mmcv/ops/csrc/pytorch/mlu/focal_loss_sigmoid_mlu.cpp b/mmcv/ops/csrc/pytorch/mlu/focal_loss_sigmoid_mlu.cpp index ff3c931737..b5633b6df5 100644 --- a/mmcv/ops/csrc/pytorch/mlu/focal_loss_sigmoid_mlu.cpp +++ b/mmcv/ops/csrc/pytorch/mlu/focal_loss_sigmoid_mlu.cpp @@ -14,9 +14,9 @@ #include "mlu_common_helper.h" -void sigmoid_focal_loss_forward_mlu(Tensor input, Tensor target, - Tensor weight, Tensor output, - const float gamma, const float alpha) { +void sigmoid_focal_loss_forward_mlu(Tensor input, Tensor target, Tensor weight, + Tensor output, const float gamma, + const float alpha) { // params check TORCH_CHECK(gamma >= 0, "gamma should be greater than or equal to 0. ", "But now gamma is ", gamma, "."); @@ -82,15 +82,15 @@ void sigmoid_focal_loss_forward_mlu(Tensor input, Tensor target, auto handle = mluOpGetCurrentHandle(); // launch kernel - TORCH_MLUOP_CHECK(mluOpFocalLossSigmoidForward(handle, prefer, reduction, input_desc.desc(), - input_ptr, target_desc.desc(), target_ptr, - weight_desc.desc(), weight_ptr, alpha, gamma, - output_desc.desc(), output_ptr)); + TORCH_MLUOP_CHECK(mluOpFocalLossSigmoidForward( + handle, prefer, reduction, input_desc.desc(), input_ptr, + target_desc.desc(), target_ptr, weight_desc.desc(), weight_ptr, alpha, + gamma, output_desc.desc(), output_ptr)); } -void sigmoid_focal_loss_backward_mlu(Tensor input, Tensor target, - Tensor weight, Tensor output, - const float gamma, const float alpha) { +void sigmoid_focal_loss_backward_mlu(Tensor input, Tensor target, Tensor weight, + Tensor output, const float gamma, + const float alpha) { // params check TORCH_CHECK(gamma >= 0, "gamma should be greater than or equal to 0. ", "But now gamma is ", gamma, "."); @@ -158,10 +158,10 @@ void sigmoid_focal_loss_backward_mlu(Tensor input, Tensor target, auto handle = mluOpGetCurrentHandle(); // launch kernel - TORCH_MLUOP_CHECK(mluOpFocalLossSigmoidBackward(handle, prefer, reduction, input_desc.desc(), - input_ptr, target_desc.desc(), target_ptr, - weight_desc.desc(), weight_ptr, alpha, gamma, - output_desc.desc(), output_ptr)); + TORCH_MLUOP_CHECK(mluOpFocalLossSigmoidBackward( + handle, prefer, reduction, input_desc.desc(), input_ptr, + target_desc.desc(), target_ptr, weight_desc.desc(), weight_ptr, alpha, + gamma, output_desc.desc(), output_ptr)); } void sigmoid_focal_loss_forward_impl(Tensor input, Tensor target, Tensor weight, diff --git a/mmcv/ops/csrc/pytorch/mlu/mlu_common_helper.h b/mmcv/ops/csrc/pytorch/mlu/mlu_common_helper.h index 362ea33564..e4eb2259bb 100644 --- a/mmcv/ops/csrc/pytorch/mlu/mlu_common_helper.h +++ b/mmcv/ops/csrc/pytorch/mlu/mlu_common_helper.h @@ -18,8 +18,8 @@ #include "pytorch_device_registry.hpp" #define MLUOP_MAJOR 0 -#define MLUOP_MINOR 7 -#define MLUOP_PATCHLEVEL 1 +#define MLUOP_MINOR 8 +#define MLUOP_PATCHLEVEL 0 /************************************************************************* * This MACRO contains operations of simple tensor to mlu-tensor. 
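
Note on the hunk above: it raises the pinned mlu-ops version from 0.7.1 to 0.8.0, the release this refactor builds against and the source of the mluOpTinShiftForward/mluOpTinShiftBackward entry points called in tin_shift_mlu.cpp below. A minimal sketch of the kind of load-time guard these macros support follows; check_mluop_version_sketch is a hypothetical helper (not part of this patch), and it assumes mluOpGetLibVersion(int*, int*, int*) from mlu_op.h, the usual Cambricon version-query signature:

#include <mlu_op.h>  // mluOpGetLibVersion (assumed available in mlu-ops)
#include <stdexcept>
#include <string>

// Mirrors the values pinned in mlu_common_helper.h by the hunk above.
#define MLUOP_MAJOR 0
#define MLUOP_MINOR 8
#define MLUOP_PATCHLEVEL 0

// Hypothetical guard: fail fast if the loaded libmluops is older than the
// headers this extension was compiled against.
inline void check_mluop_version_sketch() {
  int major = 0, minor = 0, patch = 0;
  mluOpGetLibVersion(&major, &minor, &patch);
  const int runtime = major * 10000 + minor * 100 + patch;
  const int required =
      MLUOP_MAJOR * 10000 + MLUOP_MINOR * 100 + MLUOP_PATCHLEVEL;
  if (runtime < required) {
    throw std::runtime_error("mlu-ops runtime " + std::to_string(major) + "." +
                             std::to_string(minor) + "." +
                             std::to_string(patch) +
                             " is older than the required 0.8.0");
  }
}

Running such a check once at extension load turns a missing-symbol or behavior mismatch into an actionable error message instead of a launch-time failure; the exact helper mmcv uses, if any, may differ.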
diff --git a/mmcv/ops/csrc/pytorch/mlu/roipoint_pool3d_mlu.cpp b/mmcv/ops/csrc/pytorch/mlu/roipoint_pool3d_mlu.cpp index a5cfba0ca9..109745c659 100644 --- a/mmcv/ops/csrc/pytorch/mlu/roipoint_pool3d_mlu.cpp +++ b/mmcv/ops/csrc/pytorch/mlu/roipoint_pool3d_mlu.cpp @@ -74,8 +74,8 @@ void RoIPointPool3dForwardMLUKernelLauncher( pts_feature.numel(), "."); // set contiguous - auto xyz_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous( - xyz, xyz.suggest_memory_format()); + auto xyz_contiguous = + torch_mlu::cnnl::ops::cnnl_contiguous(xyz, xyz.suggest_memory_format()); auto pts_feature_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous( pts_feature, pts_feature.suggest_memory_format()); auto boxes3d_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous( @@ -92,13 +92,16 @@ void RoIPointPool3dForwardMLUKernelLauncher( auto pts_feature_ptr = pts_feature_impl->cnnlMalloc(); auto boxes3d_impl = torch_mlu::getMluTensorImpl(boxes3d_contiguous); auto boxes3d_ptr = boxes3d_impl->cnnlMalloc(); - auto pooled_features_impl = torch_mlu::getMluTensorImpl(pooled_features_contiguous); + auto pooled_features_impl = + torch_mlu::getMluTensorImpl(pooled_features_contiguous); auto pooled_features_ptr = pooled_features_impl->cnnlMalloc(); - auto pooled_empty_flag_impl = torch_mlu::getMluTensorImpl(pooled_empty_flag_contiguous); + auto pooled_empty_flag_impl = + torch_mlu::getMluTensorImpl(pooled_empty_flag_contiguous); auto pooled_empty_flag_ptr = pooled_empty_flag_impl->cnnlMalloc(); // create tensor descriptors - MluOpTensorDescriptor xyz_desc, pts_feature_desc, boxes3d_desc, pooled_features_desc, pooled_empty_flag_desc; + MluOpTensorDescriptor xyz_desc, pts_feature_desc, boxes3d_desc, + pooled_features_desc, pooled_empty_flag_desc; xyz_desc.set(xyz_contiguous); pts_feature_desc.set(pts_feature_contiguous); boxes3d_desc.set(boxes3d_contiguous); @@ -108,10 +111,11 @@ void RoIPointPool3dForwardMLUKernelLauncher( // get workspace size_t workspace_size = 0; auto handle = mluOpGetCurrentHandle(); - TORCH_MLUOP_CHECK(mluOpGetRoiPointPool3dWorkspaceSize(handle, batch_size, - pts_num, boxes_num, feature_in_len, sampled_pts_num, xyz_desc.desc(), - pts_feature_desc.desc(), boxes3d_desc.desc(), pooled_features_desc.desc(), - pooled_empty_flag_desc.desc(), &workspace_size)); + TORCH_MLUOP_CHECK(mluOpGetRoiPointPool3dWorkspaceSize( + handle, batch_size, pts_num, boxes_num, feature_in_len, sampled_pts_num, + xyz_desc.desc(), pts_feature_desc.desc(), boxes3d_desc.desc(), + pooled_features_desc.desc(), pooled_empty_flag_desc.desc(), + &workspace_size)); auto workspace = at::empty(workspace_size, xyz.options().dtype(at::kByte)); auto workspace_impl = torch_mlu::getMluTensorImpl(workspace); @@ -120,8 +124,8 @@ void RoIPointPool3dForwardMLUKernelLauncher( handle, batch_size, pts_num, boxes_num, feature_in_len, sampled_pts_num, xyz_desc.desc(), xyz_ptr, pts_feature_desc.desc(), pts_feature_ptr, boxes3d_desc.desc(), boxes3d_ptr, workspace_ptr, workspace_size, - pooled_features_desc.desc(), pooled_features_ptr, pooled_empty_flag_desc.desc(), - (int *)pooled_empty_flag_ptr)); + pooled_features_desc.desc(), pooled_features_ptr, + pooled_empty_flag_desc.desc(), (int *)pooled_empty_flag_ptr)); } void roipoint_pool3d_forward_mlu(int batch_size, int pts_num, int boxes_num, diff --git a/mmcv/ops/csrc/pytorch/mlu/tin_shift_mlu.cpp b/mmcv/ops/csrc/pytorch/mlu/tin_shift_mlu.cpp index 728330795d..6b7714f599 100644 --- a/mmcv/ops/csrc/pytorch/mlu/tin_shift_mlu.cpp +++ b/mmcv/ops/csrc/pytorch/mlu/tin_shift_mlu.cpp @@ -9,65 +9,7 @@ * TORT OR 
OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. *************************************************************************/ -#include "pytorch_device_registry.hpp" -#include "pytorch_mlu_helper.hpp" - -void KernelTinShiftForward( - cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue, - const void *input, const void *shifts, void *output, const int batch_size, - const int time_size, const int channel_size, const int hw_size, - const int group_size, const int group_channel, - const cnrtDataType_t data_dtype, const int channel_per_core, - const int max_number_hw_per_core, const int max_length_per_core); - -void KernelTinShiftBackward( - cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue, - const void *grad_output, const void *shifts, void *grad_input, - const int batch_size, const int time_size, const int channel_size, - const int hw_size, const int group_size, const int group_channel, - const cnrtDataType_t data_dtype, const int channel_per_core, - const int max_number_hw_per_core, const int max_length_per_core); - -// policy function -static void policyFunc(const Tensor &input, cnrtDim3_t *k_dim, - cnrtFunctionType_t *k_type, int *channel_per_core, - int *max_number_hw_per_core, int *max_length_per_core) { - const int32_t cluster_limit = torch_mlu::getDeviceAttr(cnrtAttrClusterCount); - const int32_t core_limit = torch_mlu::getDeviceAttr(cnrtAttrMcorePerCluster); - auto nram_size = torch_mlu::getDeviceAttr(cnrtAttrNramSizePerMcore); - const int core_num = core_limit * cluster_limit; - const int batch_size = input.size(0); - const int time_size = input.size(1); - const int channel_size = input.size(2); - const int hw_size = input.size(3); - - const size_t size_per_channel = time_size * hw_size * input.itemsize(); - *channel_per_core = nram_size / size_per_channel; - int task_dim = 0; - if (*channel_per_core == 0) { - const size_t size_per_hw = hw_size * input.itemsize(); - *max_number_hw_per_core = nram_size / size_per_hw; - if (*max_number_hw_per_core <= 0) { - *max_length_per_core = nram_size / input.itemsize(); - } - int tmp_max_number_hw_per_core = - *max_number_hw_per_core > 0 ? *max_number_hw_per_core : 1; - const int loop_time = - (time_size / (tmp_max_number_hw_per_core)) + - ((time_size % (tmp_max_number_hw_per_core)) > 0 ? 1 : 0); - task_dim = batch_size * channel_size * loop_time < core_num - ? batch_size * channel_size * loop_time - : core_num; - } else { - task_dim = batch_size * channel_size < core_num ? batch_size * channel_size - : core_num; - } - - k_dim->x = core_limit; - k_dim->y = (task_dim / core_limit) > 0 ? 
(task_dim / core_limit) : 1; - k_dim->z = 1; - *k_type = CNRT_FUNC_TYPE_UNION1; -} +#include "mlu_common_helper.h" void TINShiftForwardMLUKernelLauncher(Tensor input, Tensor shift, Tensor output) { @@ -89,40 +31,37 @@ void TINShiftForwardMLUKernelLauncher(Tensor input, Tensor shift, if (input.size(1) == 0) { return; } - cnrtDim3_t k_dim; - cnrtFunctionType_t k_type; - int channel_per_core = 0; - int max_number_hw_per_core = 0; - int max_length_per_core = 0; - policyFunc(input, &k_dim, &k_type, &channel_per_core, &max_number_hw_per_core, - &max_length_per_core); - - const int batch_size = input.size(0); - const int time_size = input.size(1); - const int channel_size = input.size(2); - const int hw_size = input.size(3); - const int group_size = shift.size(1); - int group_channel = channel_size / group_size; - // get tensor impl - auto input_impl = torch_mlu::getMluTensorImpl(input); - auto shift_impl = torch_mlu::getMluTensorImpl(shift); - auto output_impl = torch_mlu::getMluTensorImpl(output); + // set contiguous + auto input_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous( + input, input.suggest_memory_format()); + auto shift_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous( + shift, shift.suggest_memory_format()); + auto output_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous( + output, output.suggest_memory_format()); - // get compute queue - auto queue = torch_mlu::getCurQueue(); + // get tensor impl + auto input_impl = torch_mlu::getMluTensorImpl(input_contiguous); + auto shift_impl = torch_mlu::getMluTensorImpl(shift_contiguous); + auto output_impl = torch_mlu::getMluTensorImpl(output_contiguous); // get the mlu ptr auto input_ptr = input_impl->cnnlMalloc(); auto shift_ptr = shift_impl->cnnlMalloc(); auto output_ptr = output_impl->cnnlMalloc(); - cnrtDataType_t data_dtype = torch_mlu::toCnrtDtype(input.dtype()); + // set tensor descriptor + MluOpTensorDescriptor input_desc, shift_desc, output_desc; + input_desc.set(input_contiguous); + shift_desc.set(shift_contiguous); + output_desc.set(output_contiguous); - KernelTinShiftForward(k_dim, k_type, queue, input_ptr, shift_ptr, output_ptr, - batch_size, time_size, channel_size, hw_size, - group_size, group_channel, data_dtype, channel_per_core, - max_number_hw_per_core, max_length_per_core); + // get current handle + auto handle = mluOpGetCurrentHandle(); + + TORCH_MLUOP_CHECK(mluOpTinShiftForward(handle, input_desc.desc(), input_ptr, + shift_desc.desc(), shift_ptr, + output_desc.desc(), output_ptr)); } void TINShiftBackwardMLUKernelLauncher(Tensor grad_output, Tensor shift, @@ -148,41 +87,37 @@ void TINShiftBackwardMLUKernelLauncher(Tensor grad_output, Tensor shift, if (grad_output.size(1) == 0) { return; } - cnrtDim3_t k_dim; - cnrtFunctionType_t k_type; - int channel_per_core = 0; - int max_number_hw_per_core = 0; - int max_length_per_core = 0; - policyFunc(grad_output, &k_dim, &k_type, &channel_per_core, - &max_number_hw_per_core, &max_length_per_core); - - const int batch_size = grad_output.size(0); - const int time_size = grad_output.size(1); - const int channel_size = grad_output.size(2); - const int hw_size = grad_output.size(3); - const int group_size = shift.size(1); - int group_channel = channel_size / group_size; - // get tensor impl - auto grad_output_impl = torch_mlu::getMluTensorImpl(grad_output); - auto shift_impl = torch_mlu::getMluTensorImpl(shift); - auto grad_input_impl = torch_mlu::getMluTensorImpl(grad_input); + // set contiguous + auto grad_output_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous( + 
grad_output, grad_output.suggest_memory_format()); + auto shift_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous( + shift, shift.suggest_memory_format()); + auto grad_input_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous( + grad_input, grad_input.suggest_memory_format()); - // get compute queue - auto queue = torch_mlu::getCurQueue(); + // get tensor impl + auto grad_output_impl = torch_mlu::getMluTensorImpl(grad_output_contiguous); + auto shift_impl = torch_mlu::getMluTensorImpl(shift_contiguous); + auto grad_input_impl = torch_mlu::getMluTensorImpl(grad_input_contiguous); // get the mlu ptr auto grad_output_ptr = grad_output_impl->cnnlMalloc(); auto shift_ptr = shift_impl->cnnlMalloc(); auto grad_input_ptr = grad_input_impl->cnnlMalloc(); - cnrtDataType_t data_dtype = torch_mlu::toCnrtDtype(grad_output.dtype()); + // set tensor descriptor + MluOpTensorDescriptor grad_output_desc, shift_desc, grad_input_desc; + grad_output_desc.set(grad_output_contiguous); + shift_desc.set(shift_contiguous); + grad_input_desc.set(grad_input_contiguous); + + // get current handle + auto handle = mluOpGetCurrentHandle(); - KernelTinShiftBackward(k_dim, k_type, queue, grad_output_ptr, shift_ptr, - grad_input_ptr, batch_size, time_size, channel_size, - hw_size, group_size, group_channel, data_dtype, - channel_per_core, max_number_hw_per_core, - max_length_per_core); + TORCH_MLUOP_CHECK(mluOpTinShiftBackward( + handle, grad_output_desc.desc(), grad_output_ptr, shift_desc.desc(), + shift_ptr, grad_input_desc.desc(), grad_input_ptr)); } void tin_shift_forward_mlu(Tensor input, Tensor shift, Tensor output) {