diff --git a/mmcv/ops/csrc/pytorch/mlu/ball_query_mlu.cpp b/mmcv/ops/csrc/pytorch/mlu/ball_query_mlu.cpp index efaae5e8e4..000f8882b1 100644 --- a/mmcv/ops/csrc/pytorch/mlu/ball_query_mlu.cpp +++ b/mmcv/ops/csrc/pytorch/mlu/ball_query_mlu.cpp @@ -14,11 +14,6 @@ void ball_query_forward_mlu(int b, int n, int m, float min_radius, float max_radius, int nsample, const Tensor new_xyz, const Tensor xyz, Tensor idx) { - MluOpTensorDescriptor new_xyz_desc, xyz_desc, idx_desc; - new_xyz_desc.set(new_xyz); - xyz_desc.set(xyz); - idx_desc.set(idx); - auto new_xyz_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous( new_xyz, new_xyz.suggest_memory_format()); auto xyz_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous( @@ -26,6 +21,11 @@ void ball_query_forward_mlu(int b, int n, int m, float min_radius, auto idx_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous( idx, new_xyz.suggest_memory_format()); + MluOpTensorDescriptor new_xyz_desc, xyz_desc, idx_desc; + new_xyz_desc.set(new_xyz_contiguous); + xyz_desc.set(xyz_contiguous); + idx_desc.set(idx_contiguous); + auto new_xyz_impl = torch_mlu::getMluTensorImpl(new_xyz_contiguous); auto xyz_impl = torch_mlu::getMluTensorImpl(xyz_contiguous); auto idx_impl = torch_mlu::getMluTensorImpl(idx_contiguous);