From 430ab2f828f01cece85d104de13011d7d3579d97 Mon Sep 17 00:00:00 2001 From: lanzeshun <962034936@qq.com> Date: Wed, 18 Oct 2023 16:11:33 +0800 Subject: [PATCH] [fix] fix deform_conv.py for torch_npu v2.1 --- mmcv/ops/deform_conv.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mmcv/ops/deform_conv.py b/mmcv/ops/deform_conv.py index 8251bc7328..c6cbba6779 100644 --- a/mmcv/ops/deform_conv.py +++ b/mmcv/ops/deform_conv.py @@ -51,10 +51,11 @@ def symbolic(g, @staticmethod def _npu_backward(ctx, grad_output): + import torch_npu input_tensor, weight, offset_out, offset_all, sort_index_for_npu_bp = \ ctx.saved_tensors grad_input, grad_weight, grad_offset_all, grad_bias = \ - torch.npu_deformable_conv2dbk( + torch_npu.npu_deformable_conv2dbk( input_tensor, grad_output, offset_out, weight, offset_all, kernel_size=[weight.shape[3], weight.shape[2]], stride=[1, 1, ctx.stride[0], ctx.stride[1]],