Skip to content

Commit

Permalink
https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/issues/190
Browse files Browse the repository at this point in the history
  • Loading branch information
dmaugis committed Sep 13, 2018
1 parent fc1aca1 commit 000f298
Showing 1 changed file with 19 additions and 9 deletions.
28 changes: 19 additions & 9 deletions models/networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,9 +76,13 @@ def define_G(input_nc, output_nc, ngf, netG, norm='batch', use_dropout=False, in
norm_layer = get_norm_layer(norm_type=norm)

if netG == 'resnet_9blocks':
net = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=9)
net = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=9,use_deconvolution=True)
elif netG == 'resnet_9blocks+':
net = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=9,use_deconvolution=False)
elif netG == 'resnet_6blocks':
net = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=6)
net = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=6,use_deconvolution=True)
elif netG == 'resnet_6blocks+':
net = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=6,use_deconvolution=False)
elif netG == 'unet_128':
net = UnetGenerator(input_nc, output_nc, 7, ngf, norm_layer=norm_layer, use_dropout=use_dropout)
elif netG == 'unet_256':
Expand Down Expand Up @@ -140,7 +144,7 @@ def __call__(self, input, target_is_real):
# Code and idea originally from Justin Johnson's architecture.
# https://github.com/jcjohnson/fast-neural-style/
class ResnetGenerator(nn.Module):
def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, padding_type='reflect'):
def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, padding_type='reflect',use_deconvolution=False):
assert(n_blocks >= 0)
super(ResnetGenerator, self).__init__()
self.input_nc = input_nc
Expand Down Expand Up @@ -171,12 +175,18 @@ def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_d

for i in range(n_downsampling):
mult = 2**(n_downsampling - i)
model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2),
kernel_size=3, stride=2,
padding=1, output_padding=1,
bias=use_bias),
norm_layer(int(ngf * mult / 2)),
nn.ReLU(True)]
if use_deconvolution:
model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2),
kernel_size=3, stride=2,
padding=1, output_padding=1,
bias=use_bias)]
else:
model += [nn.Upsample(scale_factor = 2, mode='bilinear',align_corners=True),
nn.ReflectionPad2d(1),
nn.Conv2d(ngf * mult, int(ngf * mult / 2),kernel_size=3, stride=1, padding=0)]

model += [norm_layer(int(ngf * mult / 2)),nn.ReLU(True)]

model += [nn.ReflectionPad2d(3)]
model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
model += [nn.Tanh()]
Expand Down

0 comments on commit 000f298

Please sign in to comment.