Skip to content

Commit

Permalink
Add test for large batches in DeformConv2d (#2040)
Browse files Browse the repository at this point in the history
* Add test for large batches in DeformConv2d

* Clean-up and (try) fix DeformConv2d

* Simplifications and bugfixes

* Try fix CUDA now
  • Loading branch information
fmassa authored Apr 2, 2020
1 parent 979bb72 commit ccd797d
Show file tree
Hide file tree
Showing 3 changed files with 109 additions and 179 deletions.
2 changes: 1 addition & 1 deletion test/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -454,7 +454,7 @@ def expected_fn(self, x, weight, offset, bias, stride=1, padding=0, dilation=1):
return out

def get_fn_args(self, device, contiguous):
batch_sz = 1
batch_sz = 33
n_in_channels = 6
n_out_channels = 2
n_weight_grps = 2
Expand Down
139 changes: 52 additions & 87 deletions torchvision/csrc/cpu/DeformConv_cpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -713,55 +713,49 @@ static std::tuple<at::Tensor, at::Tensor> deform_conv2d_backward_input_cpu(

auto grad_input = at::zeros_like(input);
auto grad_offset = at::zeros_like(offset);
auto columns = at::zeros(
auto columns = at::empty(
{n_in_channels * weight_w * weight_h, n_parallel_imgs * out_h * out_w},
input.options());

// Separate into blocks
grad_input = grad_input.view(
grad_input = grad_input.reshape(
{batch_sz / n_parallel_imgs, n_parallel_imgs, n_in_channels, in_h, in_w});
input = input.view(
input = input.reshape(
{batch_sz / n_parallel_imgs, n_parallel_imgs, n_in_channels, in_h, in_w});
grad_offset = grad_offset.view({batch_sz / n_parallel_imgs,
n_parallel_imgs,
n_offset_grps * 2 * weight_h * weight_w,
out_h,
out_w});
offset = offset.view({batch_sz / n_parallel_imgs,
n_parallel_imgs,
n_offset_grps * 2 * weight_h * weight_w,
out_h,
out_w});

grad_out = grad_out.view({batch_sz / n_parallel_imgs,
n_parallel_imgs,
n_out_channels,
out_h,
out_w});
grad_out.transpose_(1, 2);
grad_out = grad_out.view({grad_out.size(0),
n_weight_grps,
grad_out.size(1) / n_weight_grps,
grad_out.size(2),
grad_out.size(3),
grad_out.size(4)});

weight = weight.view({n_weight_grps,
weight.size(0) / n_weight_grps,
weight.size(1),
weight.size(2),
weight.size(3)});
grad_offset = grad_offset.reshape({batch_sz / n_parallel_imgs,
n_parallel_imgs,
n_offset_grps * 2 * weight_h * weight_w,
out_h,
out_w});
offset = offset.reshape({batch_sz / n_parallel_imgs,
n_parallel_imgs,
n_offset_grps * 2 * weight_h * weight_w,
out_h,
out_w});

grad_out = grad_out.reshape({batch_sz / n_parallel_imgs,
n_parallel_imgs,
n_weight_grps,
n_out_channels / n_weight_grps,
out_h,
out_w}).permute({0, 2, 3, 1, 4, 5});

weight = weight.reshape({n_weight_grps,
weight.size(0) / n_weight_grps,
weight.size(1),
weight.size(2),
weight.size(3)});

columns = columns.view(
{n_weight_grps, columns.size(0) / n_weight_grps, columns.size(1)});

for (int elt = 0; elt < batch_sz / n_parallel_imgs; elt++) {
columns.zero_();
// Separate into weight groups
columns = columns.view(
{n_weight_grps, columns.size(0) / n_weight_grps, columns.size(1)});
for (int g = 0; g < n_weight_grps; g++) {
columns[g] = columns[g].addmm_(
weight[g].flatten(1).transpose(0, 1), grad_out[elt][g].flatten(1));
}
columns =
columns.view({columns.size(0) * columns.size(1), columns.size(2)});

compute_grad_offset(
columns,
Expand Down Expand Up @@ -801,20 +795,9 @@ static std::tuple<at::Tensor, at::Tensor> deform_conv2d_backward_input_cpu(
grad_input[elt]);
}

grad_out = grad_out.view({grad_out.size(0),
grad_out.size(1) * grad_out.size(2),
grad_out.size(3),
grad_out.size(4),
grad_out.size(5)});
grad_out.transpose_(1, 2);
grad_out = grad_out.view({batch_sz, n_out_channels, out_h, out_w});

grad_input = grad_input.view({batch_sz, n_in_channels, in_h, in_w});
input = input.view({batch_sz, n_in_channels, in_h, in_w});
grad_offset = grad_offset.view(
{batch_sz, n_offset_grps * 2 * weight_h * weight_w, out_h, out_w});
offset = offset.view(
{batch_sz, n_offset_grps * 2 * weight_h * weight_w, out_h, out_w});

return std::make_tuple(grad_input, grad_offset);
}
Expand Down Expand Up @@ -854,46 +837,36 @@ static at::Tensor deform_conv2d_backward_parameters_cpu(
long out_w = grad_out.size(3);

auto grad_weight = at::zeros_like(weight);
;
auto columns = at::zeros(
{n_in_channels * weight_w * weight_h, n_parallel_imgs * out_h * out_w},
input.options());

grad_out = grad_out.view({batch_sz / n_parallel_imgs,
n_parallel_imgs,
n_out_channels,
out_h,
out_w});
grad_out.transpose_(1, 2);

at::Tensor grad_out_buf = at::zeros_like(grad_out);
grad_out_buf.copy_(grad_out);
grad_out_buf = grad_out_buf.view({batch_sz / n_parallel_imgs,
n_out_channels,
n_parallel_imgs * out_h,
out_w});
grad_out_buf = grad_out_buf.view({grad_out_buf.size(0),
n_weight_grps,
grad_out_buf.size(1) / n_weight_grps,
grad_out_buf.size(2),
grad_out_buf.size(3)});

grad_out.transpose_(1, 2);
grad_out = grad_out.view({batch_sz, n_out_channels, out_h, out_w});

input = input.view(
at::Tensor grad_out_buf = grad_out.reshape(
{batch_sz / n_parallel_imgs,
n_parallel_imgs,
n_weight_grps,
n_out_channels / n_weight_grps,
out_h,
out_w}
).permute({0, 2, 3, 1, 4, 5}).contiguous();

input = input.reshape(
{batch_sz / n_parallel_imgs, n_parallel_imgs, n_in_channels, in_h, in_w});
offset = offset.view({batch_sz / n_parallel_imgs,
n_parallel_imgs,
n_offset_grps * 2 * weight_h * weight_w,
out_h,
out_w});
offset = offset.reshape({batch_sz / n_parallel_imgs,
n_parallel_imgs,
n_offset_grps * 2 * weight_h * weight_w,
out_h,
out_w});

grad_weight = grad_weight.view({n_weight_grps,
grad_weight.size(0) / n_weight_grps,
grad_weight.size(1),
grad_weight.size(2),
grad_weight.size(3)});

auto columns = at::empty(
{n_weight_grps,
n_in_channels * weight_w * weight_h / n_weight_grps,
n_parallel_imgs * out_h * out_w},
input.options());

for (int elt = 0; elt < batch_sz / n_parallel_imgs; elt++) {
deformable_im2col(
input[elt],
Expand All @@ -915,8 +888,6 @@ static at::Tensor deform_conv2d_backward_parameters_cpu(
n_offset_grps,
columns);

columns = columns.view(
{n_weight_grps, columns.size(0) / n_weight_grps, columns.size(1)});
for (int g = 0; g < n_weight_grps; g++) {
grad_weight[g] =
grad_weight[g]
Expand All @@ -925,14 +896,8 @@ static at::Tensor deform_conv2d_backward_parameters_cpu(
grad_out_buf[elt][g].flatten(1), columns[g].transpose(1, 0))
.view_as(grad_weight[g]);
}
columns =
columns.view({columns.size(0) * columns.size(1), columns.size(2)});
}

input = input.view({batch_sz, n_in_channels, in_h, in_w});
offset = offset.view(
{batch_sz, n_offset_grps * 2 * weight_h * weight_w, out_h, out_w});

grad_weight = grad_weight.view({grad_weight.size(0) * grad_weight.size(1),
grad_weight.size(2),
grad_weight.size(3),
Expand Down
Loading

0 comments on commit ccd797d

Please sign in to comment.