diff --git a/mmcv/ops/csrc/common/cuda/active_rotated_filter_cuda_kernel.cuh b/mmcv/ops/csrc/common/cuda/active_rotated_filter_cuda_kernel.cuh index 36e41107eb..26118ac621 100644 --- a/mmcv/ops/csrc/common/cuda/active_rotated_filter_cuda_kernel.cuh +++ b/mmcv/ops/csrc/common/cuda/active_rotated_filter_cuda_kernel.cuh @@ -15,18 +15,19 @@ __global__ void active_rotated_filter_forward_cuda_kernel( const int nthreads, const scalar_t* weight_data, const int* indices_data, const int num_input_planes, const int num_output_planes, const int num_orientations, const int num_rotations, const int nEntry, - scalar_t* output_data) { + const int kH, const int kW, scalar_t* output_data) { CUDA_1D_KERNEL_LOOP(index, nthreads) { int l = index % nEntry; int j = (index / nEntry) % num_input_planes; int i = index / nEntry / num_input_planes; int k; + int fmIndex = (l / (kH * kW)) * kH * kW; scalar_t val = *(weight_data + index); for (k = 0; k < num_rotations; k++) { int idx = (int)(*(indices_data + l * num_rotations + k)) - 1; - scalar_t* target = output_data + - i * (num_rotations * num_input_planes * nEntry) + - k * (num_input_planes * nEntry) + j * (nEntry) + idx; + scalar_t* target = + output_data + i * (num_rotations * num_input_planes * nEntry) + + k * (num_input_planes * nEntry) + j * (nEntry) + idx + fmIndex; *target = val; } } @@ -37,12 +38,14 @@ __global__ void active_rotated_filter_backward_cuda_kernel( const int nthreads, const scalar_t* gradWeight_data, const int* indices_data, const int num_input_planes, const int num_output_planes, const int num_orientations, - const int num_rotations, const int nEntry, scalar_t* weight_data) { + const int num_rotations, const int nEntry, const int kH, const int kW, + scalar_t* weight_data) { CUDA_1D_KERNEL_LOOP(index, nthreads) { int l = index % nEntry; int j = (index / nEntry) % num_input_planes; int i = index / nEntry / num_input_planes; int k; + int fmIndex = (l / (kH * kW)) * kH * kW; scalar_t* val = weight_data + index; *val = 0; scalar_t tmp = 0; @@ -50,7 +53,7 @@ __global__ void active_rotated_filter_backward_cuda_kernel( int idx = (int)(*(indices_data + l * num_rotations + k)) - 1; scalar_t target = *(gradWeight_data + i * (num_rotations * num_input_planes * nEntry) + - k * (num_input_planes * nEntry) + j * (nEntry) + idx); + k * (num_input_planes * nEntry) + j * (nEntry) + idx + fmIndex); tmp = tmp + target; } *val = tmp; diff --git a/mmcv/ops/csrc/pytorch/cpu/active_rotated_filter.cpp b/mmcv/ops/csrc/pytorch/cpu/active_rotated_filter.cpp index aa5a8b3d51..c322b4044a 100644 --- a/mmcv/ops/csrc/pytorch/cpu/active_rotated_filter.cpp +++ b/mmcv/ops/csrc/pytorch/cpu/active_rotated_filter.cpp @@ -19,11 +19,12 @@ void active_rotated_filter_forward_cpu_kernel( for (l = 0; l < nEntry; l++) { int weightIndex = i * num_input_planes * nEntry + j * nEntry + l; T val = *(weightData + weightIndex); + int fmIndex = (l / (kH * kW)) * kH * kW; for (k = 0; k < num_rotations; k++) { int index = (int)(*(indicesData + l * num_rotations + k)) - 1; - T* target = outputData + - i * (num_rotations * num_input_planes * nEntry) + - k * (num_input_planes * nEntry) + j * (nEntry) + index; + T* target = + outputData + i * (num_rotations * num_input_planes * nEntry) + + k * (num_input_planes * nEntry) + j * (nEntry) + index + fmIndex; *target = val; } } @@ -48,11 +49,12 @@ void active_rotated_filter_backward_cpu_kernel( int gradInputIndex = i * num_input_planes * nEntry + j * nEntry + l; T* val = gradInputData + gradInputIndex; *val = 0; + int fmIndex = (l / (kH * kW)) * kH * kW; for (k = 0; k < num_rotations; k++) { int index = (int)(*(indicesData + l * num_rotations + k)) - 1; const T* target = gradOutputData + i * (num_rotations * num_input_planes * nEntry) + - k * (num_input_planes * nEntry) + j * (nEntry) + index; + k * (num_input_planes * nEntry) + j * (nEntry) + index + fmIndex; *val = *val + *target; } } diff --git a/mmcv/ops/csrc/pytorch/cuda/active_rotated_filter_cuda.cu b/mmcv/ops/csrc/pytorch/cuda/active_rotated_filter_cuda.cu index 27fffb9fae..025e44148f 100644 --- a/mmcv/ops/csrc/pytorch/cuda/active_rotated_filter_cuda.cu +++ b/mmcv/ops/csrc/pytorch/cuda/active_rotated_filter_cuda.cu @@ -24,7 +24,7 @@ void ActiveRotatedFilterForwardCUDAKernelLauncher(const Tensor input, <<>>( output_size, input.data_ptr(), indices.data_ptr(), num_input_planes, num_output_planes, - num_orientations, num_rotations, nEntry, + num_orientations, num_rotations, nEntry, kH, kW, output.data_ptr()); }); AT_CUDA_CHECK(cudaGetLastError()); @@ -51,7 +51,7 @@ void ActiveRotatedFilterBackwardCUDAKernelLauncher(const Tensor grad_out, <<>>( output_size, grad_out.data_ptr(), indices.data_ptr(), num_input_planes, num_output_planes, - num_orientations, num_rotations, nEntry, + num_orientations, num_rotations, nEntry, kH, kW, grad_in.data_ptr()); }); AT_CUDA_CHECK(cudaGetLastError());