
Commit

Merge pull request #5 from LielinJiang/benchmark
for Benchmark test
LielinJiang authored Aug 6, 2020
2 parents 3211114 + 3586d7d commit c56dbd8
Showing 8 changed files with 64 additions and 20 deletions.
1 change: 1 addition & 0 deletions configs/cyclegan_cityscapes.yaml
@@ -28,6 +28,7 @@ dataset:
train:
name: UnpairedDataset
dataroot: data/cityscapes
num_workers: 4
phase: train
max_dataset_size: inf
direction: AtoB
1 change: 1 addition & 0 deletions configs/pix2pix_cityscapes.yaml
@@ -25,6 +25,7 @@ dataset:
train:
name: PairedDataset
dataroot: data/cityscapes
num_workers: 4
phase: train
max_dataset_size: inf
direction: BtoA
2 changes: 1 addition & 1 deletion ppgan/datasets/builder.py
@@ -111,6 +111,6 @@ def build_dataloader(cfg, is_train=True):
batch_size = cfg.get('batch_size', 1)
num_workers = cfg.get('num_workers', 0)

dataloader = DictDataLoader(dataset, batch_size, is_train)
dataloader = DictDataLoader(dataset, batch_size, is_train, num_workers)

return dataloader
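
The num_workers: 4 entries added to the YAML configs above are consumed by the cfg.get call shown here. A minimal sketch of that lookup, assuming cfg behaves like a dict (as the get calls imply); the literal below is a hypothetical stand-in for the parsed dataset.train section:

    # Hypothetical stand-in for the parsed `dataset.train` config section.
    cfg = {'name': 'UnpairedDataset', 'dataroot': 'data/cityscapes', 'num_workers': 4}

    batch_size = cfg.get('batch_size', 1)    # falls back to 1 when absent
    num_workers = cfg.get('num_workers', 0)  # 4 here; 0 means single-process loading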
32 changes: 28 additions & 4 deletions ppgan/engine/trainer.py
@@ -2,8 +2,9 @@
import time

import logging
import paddle

from paddle.imperative import ParallelEnv
from paddle.imperative import ParallelEnv, DataParallel

from ..datasets.builder import build_dataloader
from ..models.builder import build_model
@@ -22,10 +23,13 @@ def __init__(self, cfg):

# build model
self.model = build_model(cfg)
# prepare for multi-GPU training
if ParallelEnv().nranks > 1:
self.distributed_data_parallel()

self.logger = logging.getLogger(__name__)

# base config
# self.timestamp = time.strftime('-%Y-%m-%d-%H-%M', time.localtime())
self.output_dir = cfg.output_dir
self.epochs = cfg.epochs
self.start_epoch = 0
@@ -37,25 +41,39 @@ def __init__(self, cfg):
self.cfg = cfg

self.local_rank = ParallelEnv().local_rank

# time count
self.time_count = {}

def distributed_data_parallel(self):
strategy = paddle.imperative.prepare_context()
for name in self.model.model_names:
if isinstance(name, str):
net = getattr(self.model, 'net' + name)
setattr(self.model, 'net' + name, DataParallel(net, strategy))

def train(self):

for epoch in range(self.start_epoch, self.epochs):
start_time = time.time()
self.current_epoch = epoch
start_time = step_start_time = time.time()
for i, data in enumerate(self.train_dataloader):
data_time = time.time()
self.batch_id = i
# unpack data from dataset and apply preprocessing
# data input should be dict
self.model.set_input(data)
self.model.optimize_parameters()


self.data_time = data_time - step_start_time
self.step_time = time.time() - step_start_time
if i % self.log_interval == 0:
self.print_log()

if i % self.visual_interval == 0:
self.visual('visual_train')

step_start_time = time.time()
self.logger.info('train one epoch time: {}'.format(time.time() - start_time))
if epoch % self.weight_interval == 0:
self.save(epoch, 'weight', keep=-1)
@@ -98,6 +116,12 @@ def print_log(self):
for k, v in losses.items():
message += '%s: %.3f ' % (k, v)

if hasattr(self, 'data_time'):
message += 'reader cost: %.5fs ' % self.data_time

if hasattr(self, 'step_time'):
message += 'batch cost: %.5fs' % self.step_time

# print the message
self.logger.info(message)

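Taken together, the changes to train() and print_log() add a simple reader-cost / batch-cost benchmark: data_time is stamped as soon as the loader yields a batch, so reader cost is the time spent waiting on the data pipeline and batch cost is the whole step. A minimal, runnable sketch of the same pattern, with hypothetical stand-ins for the trainer's dataloader and model:

    import time

    # Hypothetical stand-ins for self.train_dataloader and self.model.
    dataloader = [object()] * 3

    class Model:
        def set_input(self, data):
            pass

        def optimize_parameters(self):
            time.sleep(0.01)  # pretend to do a training step

    model = Model()

    step_start_time = time.time()
    for i, data in enumerate(dataloader):
        data_time = time.time()                     # batch is now in hand
        model.set_input(data)
        model.optimize_parameters()
        reader_cost = data_time - step_start_time   # time spent waiting on the reader
        batch_cost = time.time() - step_start_time  # total wall time for the step
        print('reader cost: %.5fs batch cost: %.5fs' % (reader_cost, batch_cost))
        step_start_time = time.time()               # reset the clock for the next step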
18 changes: 16 additions & 2 deletions ppgan/models/cycle_gan_model.py
@@ -1,4 +1,5 @@
import paddle
from paddle.imperative import ParallelEnv
from .base_model import BaseModel

from .builder import MODELS
@@ -137,7 +138,13 @@ def backward_D_basic(self, netD, real, fake):
loss_D_fake = self.criterionGAN(pred_fake, False)
# Combined loss and calculate gradients
loss_D = (loss_D_real + loss_D_fake) * 0.5
loss_D.backward()
# loss_D.backward()
if ParallelEnv().nranks > 1:
loss_D = netD.scale_loss(loss_D)
loss_D.backward()
netD.apply_collective_grads()
else:
loss_D.backward()
return loss_D

def backward_D_A(self):
@@ -177,7 +184,14 @@ def backward_G(self):
self.loss_cycle_B = self.criterionCycle(self.rec_B, self.real_B) * lambda_B
# combined loss and calculate gradients
self.loss_G = self.loss_G_A + self.loss_G_B + self.loss_cycle_A + self.loss_cycle_B + self.loss_idt_A + self.loss_idt_B
self.loss_G.backward()

if ParallelEnv().nranks > 1:
self.loss_G = self.netG_A.scale_loss(self.loss_G)
self.loss_G.backward()
self.netG_A.apply_collective_grads()
self.netG_B.apply_collective_grads()
else:
self.loss_G.backward()

def optimize_parameters(self):
"""Calculate losses, gradients, and update network weights; called in every training iteration"""
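The new multi-GPU branches follow the legacy paddle.imperative dygraph recipe that trainer.py sets up: wrap each network in DataParallel, then scale the loss before backward() and all-reduce the gradients afterwards. A condensed sketch using only the calls that appear in this diff; nets, net, and loss are hypothetical placeholders:

    import paddle
    from paddle.imperative import ParallelEnv, DataParallel

    # Mirrors Trainer.distributed_data_parallel: one shared strategy,
    # every network wrapped in DataParallel.
    def wrap_for_multi_gpu(nets):
        strategy = paddle.imperative.prepare_context()
        return [DataParallel(net, strategy) for net in nets]

    # Mirrors the new backward_* branches above.
    def backward_collective(net, loss):
        if ParallelEnv().nranks > 1:
            loss = net.scale_loss(loss)   # rescale so averaged gradients match single-GPU training
            loss.backward()
            net.apply_collective_grads()  # all-reduce gradients across ranks
        else:
            loss.backward()

Note that when one loss spans several wrapped networks, as with loss_G above, the diff scales it once through netG_A and then calls apply_collective_grads() on both netG_A and netG_B.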
6 changes: 1 addition & 5 deletions ppgan/models/generators/resnet.py
@@ -36,11 +36,8 @@ def __init__(self, input_nc, output_nc, ngf=64, norm_type='instance', use_dropou
else:
use_bias = norm_layer == nn.InstanceNorm

print('norm layer:', norm_layer, 'use bias:', use_bias)

model = [ReflectionPad2d(3),
nn.Conv2D(input_nc, ngf, filter_size=7, padding=0, bias_attr=use_bias),
# nn.nn.Conv2D(input_nc, ngf, filter_size=7, padding=0, bias_attr=use_bias),
norm_layer(ngf),
nn.ReLU()]

@@ -62,8 +59,7 @@ def __init__(self, input_nc, output_nc, ngf=64, norm_type='instance', use_dropou
model += [
nn.Conv2DTranspose(ngf * mult, int(ngf * mult / 2),
filter_size=3, stride=2,
padding=1, #output_padding=1,
# padding='same', #output_padding=1,
padding=1,
bias_attr=use_bias),
Pad2D(paddings=[0, 1, 0, 1], mode='constant', pad_value=0.0),
norm_layer(int(ngf * mult / 2)),
18 changes: 14 additions & 4 deletions ppgan/models/pix2pix_model.py
@@ -1,4 +1,5 @@
import paddle
from paddle.imperative import ParallelEnv
from .base_model import BaseModel

from .builder import MODELS
@@ -43,7 +44,6 @@ def __init__(self, opt):
# define networks (both generator and discriminator)
self.netG = build_generator(opt.model.generator)


# define a discriminator; conditional GANs need to take both input and output images; Therefore, #channels for D is input_nc + output_nc
if self.isTrain:
self.netD = build_discriminator(opt.model.discriminator)
@@ -98,7 +98,12 @@ def backward_D(self):
self.loss_D_real = self.criterionGAN(pred_real, True)
# combine loss and calculate gradients
self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5
self.loss_D.backward()
if ParallelEnv().nranks > 1:
self.loss_D = self.netD.scale_loss(self.loss_D)
self.loss_D.backward()
self.netD.apply_collective_grads()
else:
self.loss_D.backward()

def backward_G(self):
"""Calculate GAN and L1 loss for the generator"""
@@ -110,8 +115,13 @@ def backward_G(self):
self.loss_G_L1 = self.criterionL1(self.fake_B, self.real_B) * self.opt.lambda_L1
# combine loss and calculate gradients
self.loss_G = self.loss_G_GAN + self.loss_G_L1
# self.loss_G = self.loss_G_L1
self.loss_G.backward()

if ParallelEnv().nranks > 1:
self.loss_G = self.netG.scale_loss(self.loss_G)
self.loss_G.backward()
self.netG.apply_collective_grads()
else:
self.loss_G.backward()

def optimize_parameters(self):
# compute fake images: G(A)
6 changes: 2 additions & 4 deletions ppgan/utils/filesystem.py
@@ -11,15 +11,13 @@ def save(state_dicts, file_name):

def convert(state_dict):
model_dict = {}
# name_table = {}

for k, v in state_dict.items():
if isinstance(v, (paddle.framework.Variable, paddle.imperative.core.VarBase)):
model_dict[k] = v.numpy()
else:
model_dict[k] = v
return state_dict
# name_table[k] = v.name
# model_dict["StructuredToParameterName@@"] = name_table

return model_dict

final_dict = {}
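The net effect of this last hunk is a bug fix: convert() used to return its input state_dict unchanged, discarding the numpy conversion it had just built. The corrected helper reads as follows (assuming paddle is imported at the top of filesystem.py, as the isinstance check implies):

    def convert(state_dict):
        model_dict = {}
        for k, v in state_dict.items():
            # Pull tensors back to host memory as numpy arrays; pass
            # anything else through unchanged.
            if isinstance(v, (paddle.framework.Variable,
                              paddle.imperative.core.VarBase)):
                model_dict[k] = v.numpy()
            else:
                model_dict[k] = v
        return model_dict  # previously: return state_dict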
