diff --git a/paddleseg/models/sfnet.py b/paddleseg/models/sfnet.py index 43b34baa72..fe48911719 100644 --- a/paddleseg/models/sfnet.py +++ b/paddleseg/models/sfnet.py @@ -80,7 +80,7 @@ def forward(self, x): logit_list = [ F.interpolate( logit, - x.shape[2:], + paddle.shape(x)[2:], mode='bilinear', align_corners=self.align_corners) for logit in logit_list ] @@ -165,7 +165,7 @@ def forward(self, conv_out): out.append(self.dsn[i](f)) fpn_feature_list.reverse() - output_size = fpn_feature_list[0].shape[2:] + output_size = paddle.shape(fpn_feature_list[0])[2:] fusion_list = [fpn_feature_list[0]] for i in range(1, len(fpn_feature_list)): @@ -205,24 +205,25 @@ def __init__(self, inplane, outplane, kernel_size=3): padding=1, bias_attr=False) - def flow_warp(self, inputs, flow, size): - out_h, out_w = size - n, c, h, w = inputs.shape - norm = paddle.to_tensor([[[[out_w, out_h]]]]).astype('float32') - h = paddle.linspace(-1.0, 1.0, out_h).reshape([-1, 1]).tile([1, out_w]) - w = paddle.linspace(-1.0, 1.0, out_w).tile([out_h, 1]) - grid = paddle.concat([paddle.unsqueeze(w, 2), - paddle.unsqueeze(h, 2)], 2) - grid = grid.tile([n, 1, 1, 1]).astype('float32') - grid = grid + flow.transpose([0, 2, 3, 1]) / norm - output = F.grid_sample(inputs, grid) + def flow_warp(self, input, flow, size): + input_shape = paddle.shape(input) + norm = size[::-1].reshape([1, 1, 1, -1]) + norm.stop_gradient = True + h_grid = paddle.linspace(-1.0, 1.0, size[0]).reshape([-1, 1]) + h_grid = h_grid.tile([size[1]]) + w_grid = paddle.linspace(-1.0, 1.0, size[1]).reshape([-1, 1]) + w_grid = w_grid.tile([size[0]]).transpose([1, 0]) + grid = paddle.concat([w_grid.unsqueeze(2), h_grid.unsqueeze(2)], axis=2) + grid.unsqueeze(0).tile([input_shape[0], 1, 1, 1]) + grid = grid + paddle.transpose(flow, (0, 2, 3, 1)) / norm + + output = F.grid_sample(input, grid) return output def forward(self, x): low_feature, h_feature = x h_feature_orign = h_feature - h, w = low_feature.shape[2:] - size = (h, w) + size = paddle.shape(low_feature)[2:] low_feature = self.down_l(low_feature) h_feature = self.down_h(h_feature) h_feature = F.interpolate(