From 519abd539e38bfd190fb4b525fd13a4cd228bd80 Mon Sep 17 00:00:00 2001 From: chenguowei01 Date: Mon, 22 Feb 2021 20:04:10 +0800 Subject: [PATCH 1/2] add dygraph fleet --- paddleseg/core/train.py | 33 +++++++++++++++++++++------------ 1 file changed, 21 insertions(+), 12 deletions(-) diff --git a/paddleseg/core/train.py b/paddleseg/core/train.py index 18dc77fa7b..a1c25a9f27 100644 --- a/paddleseg/core/train.py +++ b/paddleseg/core/train.py @@ -96,14 +96,19 @@ def train(model, os.remove(save_dir) os.makedirs(save_dir) + # if nranks > 1: + # # Initialize parallel environment if not done. + # if not paddle.distributed.parallel.parallel_helper._is_parallel_ctx_initialized( + # ): + # paddle.distributed.init_parallel_env() + # ddp_model = paddle.DataParallel(model) + # else: + # ddp_model = paddle.DataParallel(model) + if nranks > 1: - # Initialize parallel environment if not done. - if not paddle.distributed.parallel.parallel_helper._is_parallel_ctx_initialized( - ): - paddle.distributed.init_parallel_env() - ddp_model = paddle.DataParallel(model) - else: - ddp_model = paddle.DataParallel(model) + paddle.distributed.fleet.init(is_collective=True) + optimizer = paddle.distributed.fleet.distributed_optimizer(optimizer) + ddp_model = paddle.distributed.fleet.distributed_model(model) batch_sampler = paddle.io.DistributedBatchSampler( train_dataset, batch_size=batch_size, shuffle=True, drop_last=True) @@ -159,8 +164,6 @@ def train(model, losses=losses, edges=edges) loss = sum(loss_list) - # loss.backward() - # optimizer.step() scaled = scaler.scale(loss) # scale the loss scaled.backward() # do backward @@ -180,9 +183,15 @@ def train(model, optimizer.step() lr = optimizer.get_lr() - if isinstance(optimizer._learning_rate, - paddle.optimizer.lr.LRScheduler): - optimizer._learning_rate.step() + + # update lr + if isinstance(optimizer, paddle.distributed.fleet.Fleet): + lr_sche = optimizer.user_defined_optimizer._learning_rate + else: + lr_sche = optimizer._learning_rate + if isinstance(lr_sche, paddle.optimizer.lr.LRScheduler): + lr_sche.step() + model.clear_gradients() avg_loss += loss.numpy()[0] if not avg_loss_list: From bc02429ddd63e3c88d671efcb0aa440a7c6836ca Mon Sep 17 00:00:00 2001 From: chenguowei01 Date: Tue, 23 Feb 2021 17:25:44 +0800 Subject: [PATCH 2/2] fix padding to symmetry --- paddleseg/models/backbones/hrnet.py | 23 +++++++++-------------- paddleseg/models/fcn.py | 2 -- 2 files changed, 9 insertions(+), 16 deletions(-) diff --git a/paddleseg/models/backbones/hrnet.py b/paddleseg/models/backbones/hrnet.py index 40ed660d9d..52cec1718c 100644 --- a/paddleseg/models/backbones/hrnet.py +++ b/paddleseg/models/backbones/hrnet.py @@ -94,7 +94,7 @@ def __init__(self, out_channels=64, kernel_size=3, stride=2, - padding='same', + padding=1, bias_attr=False) self.conv_layer1_2 = layers.ConvBNReLU( @@ -102,7 +102,7 @@ def __init__(self, out_channels=64, kernel_size=3, stride=2, - padding='same', + padding=1, bias_attr=False) self.la1 = Layer1( @@ -243,7 +243,7 @@ def __init__(self, in_channels, out_channels, name=None): in_channels=in_channels[i], out_channels=out_channels[i], kernel_size=3, - padding='same', + padding=1, bias_attr=False)) else: residual = self.add_sublayer( @@ -253,7 +253,7 @@ def __init__(self, in_channels, out_channels, name=None): out_channels=out_channels[i], kernel_size=3, stride=2, - padding='same', + padding=1, bias_attr=False)) self.conv_bn_func_list.append(residual) @@ -322,7 +322,6 @@ def __init__(self, in_channels=num_channels, out_channels=num_filters, kernel_size=1, - padding='same', bias_attr=False) self.conv2 = layers.ConvBNReLU( @@ -330,14 +329,13 @@ def __init__(self, out_channels=num_filters, kernel_size=3, stride=stride, - padding='same', + padding=1, bias_attr=False) self.conv3 = layers.ConvBN( in_channels=num_filters, out_channels=num_filters * 4, kernel_size=1, - padding='same', bias_attr=False) if self.downsample: @@ -345,7 +343,6 @@ def __init__(self, in_channels=num_channels, out_channels=num_filters * 4, kernel_size=1, - padding='same', bias_attr=False) if self.has_se: @@ -390,13 +387,13 @@ def __init__(self, out_channels=num_filters, kernel_size=3, stride=stride, - padding='same', + padding=1, bias_attr=False) self.conv2 = layers.ConvBN( in_channels=num_filters, out_channels=num_filters, kernel_size=3, - padding='same', + padding=1, bias_attr=False) if self.downsample: @@ -404,7 +401,6 @@ def __init__(self, in_channels=num_channels, out_channels=num_filters, kernel_size=1, - padding='same', bias_attr=False) if self.has_se: @@ -567,7 +563,6 @@ def __init__(self, in_channels=in_channels[j], out_channels=out_channels[i], kernel_size=1, - padding='same', bias_attr=False)) self.residual_func_list.append(residual_func) elif j < i: @@ -582,7 +577,7 @@ def __init__(self, out_channels=out_channels[i], kernel_size=3, stride=2, - padding='same', + padding=1, bias_attr=False)) pre_num_filters = out_channels[i] else: @@ -594,7 +589,7 @@ def __init__(self, out_channels=out_channels[j], kernel_size=3, stride=2, - padding='same', + padding=1, bias_attr=False)) pre_num_filters = out_channels[j] self.residual_func_list.append(residual_func) diff --git a/paddleseg/models/fcn.py b/paddleseg/models/fcn.py index 4d2915976c..921c4d827a 100644 --- a/paddleseg/models/fcn.py +++ b/paddleseg/models/fcn.py @@ -113,7 +113,6 @@ def __init__(self, in_channels=backbone_channels[0], out_channels=channels, kernel_size=1, - padding='same', stride=1, bias_attr=bias) self.cls = nn.Conv2D( @@ -121,7 +120,6 @@ def __init__(self, out_channels=self.num_classes, kernel_size=1, stride=1, - padding=0, bias_attr=bias) self.init_weight()