
Commit

Merge pull request #847 from wuyefeilin/benchmark_amp_dis
wuyefeilin authored Feb 23, 2021
2 parents e881e11 + bc02429 commit 4fd3da6
Showing 3 changed files with 30 additions and 28 deletions.
33 changes: 21 additions & 12 deletions paddleseg/core/train.py
@@ -96,14 +96,19 @@ def train(model,
             os.remove(save_dir)
         os.makedirs(save_dir)
 
+    # if nranks > 1:
+    #     # Initialize parallel environment if not done.
+    #     if not paddle.distributed.parallel.parallel_helper._is_parallel_ctx_initialized(
+    #     ):
+    #         paddle.distributed.init_parallel_env()
+    #         ddp_model = paddle.DataParallel(model)
+    #     else:
+    #         ddp_model = paddle.DataParallel(model)
+
     if nranks > 1:
-        # Initialize parallel environment if not done.
-        if not paddle.distributed.parallel.parallel_helper._is_parallel_ctx_initialized(
-        ):
-            paddle.distributed.init_parallel_env()
-            ddp_model = paddle.DataParallel(model)
-        else:
-            ddp_model = paddle.DataParallel(model)
+        paddle.distributed.fleet.init(is_collective=True)
+        optimizer = paddle.distributed.fleet.distributed_optimizer(optimizer)
+        ddp_model = paddle.distributed.fleet.distributed_model(model)
 
     batch_sampler = paddle.io.DistributedBatchSampler(
         train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
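
This hunk moves the distributed setup from plain paddle.DataParallel to the fleet collective API. A minimal sketch of the new pattern, assuming PaddlePaddle 2.0 and a start via `python -m paddle.distributed.launch` (the Linear model and optimizer here are hypothetical stand-ins, not the repo's training code):

    import paddle
    import paddle.distributed.fleet as fleet

    model = paddle.nn.Linear(16, 4)  # stand-in for the segmentation model
    optimizer = paddle.optimizer.Momentum(
        learning_rate=paddle.optimizer.lr.PolynomialDecay(
            learning_rate=0.01, decay_steps=1000),
        parameters=model.parameters())

    nranks = paddle.distributed.ParallelEnv().nranks
    if nranks > 1:
        fleet.init(is_collective=True)                      # collective (multi-GPU) mode
        optimizer = fleet.distributed_optimizer(optimizer)  # fleet wraps the optimizer
        ddp_model = fleet.distributed_model(model)          # ...and wraps the model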
@@ -159,8 +164,6 @@ def train(model,
                 losses=losses,
                 edges=edges)
             loss = sum(loss_list)
-            # loss.backward()
-            # optimizer.step()
 
             scaled = scaler.scale(loss)  # scale the loss
             scaled.backward()  # do backward
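
For context, the scaled backward pass above is the standard paddle.amp dynamic loss-scaling step (Paddle 2.0 API). A self-contained sketch; the model, data, and optimizer are hypothetical stand-ins:

    import paddle

    model = paddle.nn.Linear(8, 2)                  # hypothetical stand-ins
    images = paddle.randn([4, 8])
    labels = paddle.randint(0, 2, [4])
    optimizer = paddle.optimizer.Momentum(
        learning_rate=0.01, parameters=model.parameters())

    scaler = paddle.amp.GradScaler(init_loss_scaling=1024)
    with paddle.amp.auto_cast():                    # run fp16 where safe
        logits = model(images)
        loss = paddle.nn.functional.cross_entropy(logits, labels)

    scaled = scaler.scale(loss)         # scale the loss to avoid fp16 gradient underflow
    scaled.backward()                   # gradients carry the scale factor
    scaler.minimize(optimizer, scaled)  # unscale, skip the step on overflow
    optimizer.clear_grad()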
@@ -180,9 +183,15 @@ def train(model,
             optimizer.step()
 
             lr = optimizer.get_lr()
-            if isinstance(optimizer._learning_rate,
-                          paddle.optimizer.lr.LRScheduler):
-                optimizer._learning_rate.step()
+
+            # update lr
+            if isinstance(optimizer, paddle.distributed.fleet.Fleet):
+                lr_sche = optimizer.user_defined_optimizer._learning_rate
+            else:
+                lr_sche = optimizer._learning_rate
+            if isinstance(lr_sche, paddle.optimizer.lr.LRScheduler):
+                lr_sche.step()
+
             model.clear_gradients()
             avg_loss += loss.numpy()[0]
             if not avg_loss_list:
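
The reason for the new indirection: fleet.distributed_optimizer returns a Fleet object that wraps the original optimizer, so the LR scheduler must now be reached through user_defined_optimizer. A small helper sketching the same logic, assuming the fleet setup above:

    import paddle
    import paddle.distributed.fleet as fleet

    def get_lr_scheduler(optimizer):
        # Fleet exposes the wrapped optimizer as `user_defined_optimizer`.
        if isinstance(optimizer, fleet.Fleet):
            lr_sche = optimizer.user_defined_optimizer._learning_rate
        else:
            lr_sche = optimizer._learning_rate
        if isinstance(lr_sche, paddle.optimizer.lr.LRScheduler):
            return lr_sche
        return None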
23 changes: 9 additions & 14 deletions paddleseg/models/backbones/hrnet.py
@@ -94,15 +94,15 @@ def __init__(self,
             out_channels=64,
             kernel_size=3,
             stride=2,
-            padding='same',
+            padding=1,
             bias_attr=False)
 
         self.conv_layer1_2 = layers.ConvBNReLU(
             in_channels=64,
             out_channels=64,
             kernel_size=3,
             stride=2,
-            padding='same',
+            padding=1,
             bias_attr=False)
 
         self.la1 = Layer1(
@@ -243,7 +243,7 @@ def __init__(self, in_channels, out_channels, name=None):
                         in_channels=in_channels[i],
                         out_channels=out_channels[i],
                         kernel_size=3,
-                        padding='same',
+                        padding=1,
                         bias_attr=False))
             else:
                 residual = self.add_sublayer(
@@ -253,7 +253,7 @@ def __init__(self, in_channels, out_channels, name=None):
                         out_channels=out_channels[i],
                         kernel_size=3,
                         stride=2,
-                        padding='same',
+                        padding=1,
                         bias_attr=False))
             self.conv_bn_func_list.append(residual)

@@ -322,30 +322,27 @@ def __init__(self,
             in_channels=num_channels,
             out_channels=num_filters,
             kernel_size=1,
-            padding='same',
             bias_attr=False)
 
         self.conv2 = layers.ConvBNReLU(
             in_channels=num_filters,
             out_channels=num_filters,
             kernel_size=3,
             stride=stride,
-            padding='same',
+            padding=1,
             bias_attr=False)
 
         self.conv3 = layers.ConvBN(
             in_channels=num_filters,
             out_channels=num_filters * 4,
             kernel_size=1,
-            padding='same',
             bias_attr=False)
 
         if self.downsample:
             self.conv_down = layers.ConvBN(
                 in_channels=num_channels,
                 out_channels=num_filters * 4,
                 kernel_size=1,
-                padding='same',
                 bias_attr=False)
 
         if self.has_se:
@@ -390,21 +387,20 @@ def __init__(self,
             out_channels=num_filters,
             kernel_size=3,
             stride=stride,
-            padding='same',
+            padding=1,
             bias_attr=False)
         self.conv2 = layers.ConvBN(
             in_channels=num_filters,
             out_channels=num_filters,
             kernel_size=3,
-            padding='same',
+            padding=1,
             bias_attr=False)
 
         if self.downsample:
             self.conv_down = layers.ConvBNReLU(
                 in_channels=num_channels,
                 out_channels=num_filters,
                 kernel_size=1,
-                padding='same',
                 bias_attr=False)
 
         if self.has_se:
@@ -567,7 +563,6 @@ def __init__(self,
                             in_channels=in_channels[j],
                             out_channels=out_channels[i],
                             kernel_size=1,
-                            padding='same',
                             bias_attr=False))
                     self.residual_func_list.append(residual_func)
                 elif j < i:
@@ -582,7 +577,7 @@ def __init__(self,
                                     out_channels=out_channels[i],
                                     kernel_size=3,
                                     stride=2,
-                                    padding='same',
+                                    padding=1,
                                     bias_attr=False))
                             pre_num_filters = out_channels[i]
                         else:
@@ -594,7 +589,7 @@ def __init__(self,
                                     out_channels=out_channels[j],
                                     kernel_size=3,
                                     stride=2,
-                                    padding='same',
+                                    padding=1,
                                     bias_attr=False))
                             pre_num_filters = out_channels[j]
                     self.residual_func_list.append(residual_func)
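Why these padding edits are behavior-preserving: for a 3x3 convolution, an explicit padding=1 produces the same output size as padding='same' at stride 1 and stride 2, and for a 1x1 convolution 'same' padding is already zero, i.e. the Conv2D default, so the argument is simply dropped. The same reasoning applies to the fcn.py change below. A quick shape check, assuming plain paddle.nn.Conv2D layers (not the repo's ConvBNReLU wrappers):

    import paddle

    x = paddle.randn([1, 3, 64, 64])

    # 3x3 conv: padding=1 reproduces the 'same' output size at stride 1 and 2.
    for stride in (1, 2):
        same = paddle.nn.Conv2D(3, 8, kernel_size=3, stride=stride, padding='same')
        expl = paddle.nn.Conv2D(3, 8, kernel_size=3, stride=stride, padding=1)
        assert same(x).shape == expl(x).shape

    # 1x1 conv: 'same' padding is zero padding, which is the Conv2D default.
    assert paddle.nn.Conv2D(3, 8, kernel_size=1, padding='same')(x).shape == \
           paddle.nn.Conv2D(3, 8, kernel_size=1)(x).shape
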
2 changes: 0 additions & 2 deletions paddleseg/models/fcn.py
@@ -113,15 +113,13 @@ def __init__(self,
             in_channels=backbone_channels[0],
             out_channels=channels,
             kernel_size=1,
-            padding='same',
             stride=1,
             bias_attr=bias)
         self.cls = nn.Conv2D(
             in_channels=channels,
             out_channels=self.num_classes,
             kernel_size=1,
             stride=1,
-            padding=0,
             bias_attr=bias)
         self.init_weight()

