
any suggestion on how to get rid of the checkered effect #190

Closed
NeuralBricolage opened this issue Jan 18, 2018 · 14 comments

@NeuralBricolage

i'm training CycleGAN on my food photos and my sketches - the results are fascinating esp colors and shapes - if only i knew how to get rid of that checkered effect...
[attached result images: 38744547705_046ce1d663_o, 39619927772_e0c97ed1b0_o]

@phillipi
Collaborator

Very cool! Care to share what the inputs look like too?

This Distill paper talks about one of the causes of the checkerboard artifacts. You can fix that issue by switching from "deconvolution" to nearest-neighbor upsampling followed by a regular convolution. I think @ssnl may have implemented this at some point.

We've also noticed that sometimes the checkerboard artifacts go away if you simply train long enough. Maybe try training a bit longer.

Another cause of repetitive artifacts can be that the discriminator's receptive field is too small. For some discussion on this, please see Section 4.4 and Figure 6 of the pix2pix paper. The issue is that if the discriminator looks at too myopic a region, it won't notice that textures are repeating. I think this is probably not the case in your results, but it's something to keep in mind.
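To make the receptive-field point concrete, here is a small sketch that computes how large an input patch one discriminator output unit sees. The helper function and the layer list are illustrative assumptions (the standard 70x70 PatchGAN stack of 4x4 convs: three stride-2 layers followed by two stride-1 layers); check the repo's discriminator definition for the exact configuration:

```python
# Receptive-field calculator for a stack of conv layers.
# Layer list below is an assumption: the standard 70x70 PatchGAN uses
# three stride-2 4x4 convs followed by two stride-1 4x4 convs.
def receptive_field(layers):
    """layers: list of (kernel_size, stride) pairs, input to output."""
    rf, jump = 1, 1  # rf: patch size seen; jump: input pixels per output step
    for k, s in layers:
        rf += (k - 1) * jump
        jump *= s
    return rf

print(receptive_field([(4, 2), (4, 2), (4, 2), (4, 1), (4, 1)]))  # 70
```

If textures repeat at a scale larger than this number, the discriminator cannot penalize the repetition; adding layers or stride grows the receptive field.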

@NeuralBricolage
Author

NeuralBricolage commented Jan 18, 2018

thank you Phillip, great to have a path forward - checking the links!
the original images in my training sets are pretty high res, but i'm bounded by a GTX 1080 - both the A and B datasets are around 1k images each
i trained on loadSize=384 and on loadSize=1024 with fineSize=384, while testing at a higher res, and was not that happy with the results, the original image being too pronounced...
so that's where i am now, training with loadSize=768, fineSize=384

@ssnl
Collaborator

ssnl commented Jan 18, 2018

@AllAwake Cool results!

Here is the implementation of resize-conv I used. It removes the checkerboard artifacts during early training. You may find it useful.

    nn.Upsample(scale_factor=2, mode='bilinear'),
    nn.ReflectionPad2d(1),
    nn.Conv2d(ngf * mult, int(ngf * mult / 2),
              kernel_size=3, stride=1, padding=0),

It should replace the ConvTranspose2d in ResnetGenerator.
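As a self-contained sanity check, here is a minimal sketch (with illustrative values ngf=64, mult=4, and a kernel_size=3, stride=2, padding=1, output_padding=1 deconv; not the actual ResnetGenerator code) showing that the resize-conv block is shape-compatible with the ConvTranspose2d it replaces:

```python
import torch
import torch.nn as nn

ngf, mult = 64, 4  # illustrative values, not the repo's exact loop state

# Original upsampling step: transposed conv, prone to checkerboard artifacts.
deconv = nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2),
                            kernel_size=3, stride=2,
                            padding=1, output_padding=1)

# Resize-conv replacement: upsample first, then a stride-1 convolution.
resize_conv = nn.Sequential(
    nn.Upsample(scale_factor=2, mode='bilinear'),
    nn.ReflectionPad2d(1),
    nn.Conv2d(ngf * mult, int(ngf * mult / 2),
              kernel_size=3, stride=1, padding=0),
)

x = torch.randn(1, ngf * mult, 64, 64)
print(deconv(x).shape)       # torch.Size([1, 128, 128, 128])
print(resize_conv(x).shape)  # torch.Size([1, 128, 128, 128])
```

Since both blocks map the same channel and spatial sizes, the swap is a drop-in change in the generator's decoder loop.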

dmaugis added a commit to dmaugis/pytorch-CycleGAN-and-pix2pix that referenced this issue Sep 13, 2018
franzmoca added a commit to franzmoca/pytorch-CycleGAN-and-pix2pix that referenced this issue Oct 11, 2018
@codaibk

codaibk commented Nov 14, 2018

Hi @ssnl and @phillipi,
I'm currently researching low-to-high resolution image translation. My generator and discriminator architectures are the same as your CycleGAN's. I followed your guide to replace the ConvTranspose2d, but it seems the checkerboard artifacts still appear in my results. Following the Distill paper, they mention resizing the image (using nearest-neighbor or bilinear interpolation) and also changing something in the discriminator. Could you please tell me how to implement that to remove the checkerboard artifacts? This is my result:
Input: [image 116_real_a]
Result: [image 116_fake_b2]

@junyanz
Owner

junyanz commented Nov 14, 2018

I haven't tried the Distill tricks myself. For discriminators, they mentioned that you can replace the stride-2 conv with a regular 3x3 conv.

@codaibk

codaibk commented Nov 15, 2018

> I haven't tried the Distill tricks myself. For discriminators, they mentioned that you can replace the stride-2 conv with a regular 3x3 conv.

Does a regular 3x3 conv mean that we just need to change to stride=1 in this case?

@junyanz
Owner

junyanz commented Nov 15, 2018

Yeah, I guess you also need to add a downsampling layer. You can look at Table 2 in the progressive GANs paper.
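A minimal sketch of that swap (illustrative channel sizes, not the repo's discriminator code; the conv-then-pool pattern follows Table 2 of the progressive GANs paper):

```python
import torch
import torch.nn as nn

in_ch, out_ch = 64, 128  # hypothetical channel sizes for illustration

# Strided conv: downsamples inside the convolution (PatchGAN-style).
strided = nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=2, padding=1)

# Alternative: stride-1 3x3 conv, then an explicit average-pool downsample.
conv_then_pool = nn.Sequential(
    nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1),
    nn.AvgPool2d(kernel_size=2),
)

x = torch.randn(1, in_ch, 32, 32)
print(strided(x).shape)         # torch.Size([1, 128, 16, 16])
print(conv_then_pool(x).shape)  # torch.Size([1, 128, 16, 16])
```

Both variants halve the spatial size, but the second separates filtering from downsampling, which is the change the Distill article suggests on the discriminator side.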

@yichunk
Contributor

yichunk commented Jan 30, 2019

> @AllAwake Cool results!
>
> Here is the implementation of resize-conv I used. It removes the checkerboard artifacts during early training. You may find it useful.
>
>     nn.Upsample(scale_factor=2, mode='bilinear'),
>     nn.ReflectionPad2d(1),
>     nn.Conv2d(ngf * mult, int(ngf * mult / 2),
>               kernel_size=3, stride=1, padding=0),
>
> It should replace the ConvTranspose2d in ResnetGenerator.

Hi @ssnl,
may I ask why you chose bilinear upsampling instead of the nearest-neighbor one? The Distill paper pointed out that the result of nearest-neighbor interpolation should be better.

@ssnl
Collaborator

ssnl commented Jan 30, 2019

@jacky841102 You could try nearest neighbor. I think it didn't make much difference on the dataset I tried.

@mrgloom

mrgloom commented Sep 3, 2019

I wonder why pix2pixHD is using nn.ConvTranspose2d anyway?
NVIDIA/pix2pixHD#149

Same with StyleGAN
https://github.com/NVlabs/stylegan/blob/f3a044621e2ab802d40940c16cc86042ae87e100/training/networks_stylegan.py#L174

and ProGAN
https://github.com/NVlabs/stylegan/blob/f3a044621e2ab802d40940c16cc86042ae87e100/training/networks_progan.py#L89

@mrgloom

mrgloom commented Sep 3, 2019

Also, according to this post: http://warmspringwinds.github.io/tensorflow/tf-slim/2016/11/22/upsampling-and-image-segmentation-with-tensorflow-and-tf-slim/
a transposed convolution can be initialized with a bilinear filter.
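For reference, a transposed conv initialized with a bilinear kernel can be sketched as follows (a minimal example; the kernel construction follows the standard bilinear-upsampling recipe, and the channel count is illustrative):

```python
import torch
import torch.nn as nn

def bilinear_kernel(channels, kernel_size):
    """Build a (channels, channels, k, k) weight that maps each channel to
    itself with a 2D bilinear upsampling filter."""
    factor = (kernel_size + 1) // 2
    center = factor - 1 if kernel_size % 2 == 1 else factor - 0.5
    og = torch.arange(kernel_size).float()
    filt_1d = 1 - torch.abs(og - center) / factor
    filt = filt_1d[:, None] * filt_1d[None, :]
    weight = torch.zeros(channels, channels, kernel_size, kernel_size)
    for c in range(channels):
        weight[c, c] = filt  # identity channel mapping, bilinear spatially
    return weight

# 2x upsampling deconv initialized to bilinear interpolation.
channels = 3  # illustrative
deconv = nn.ConvTranspose2d(channels, channels, kernel_size=4,
                            stride=2, padding=1, bias=False)
with torch.no_grad():
    deconv.weight.copy_(bilinear_kernel(channels, 4))

x = torch.randn(1, channels, 8, 8)
y = deconv(x)
print(y.shape)  # torch.Size([1, 3, 16, 16])
```

With this initialization the layer starts as plain bilinear upsampling (a constant input stays constant away from the borders) and is then free to learn away from it.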

@yanfei-zhang-95

yanfei-zhang-95 commented Oct 2, 2019

> @AllAwake Cool results!
>
> Here is the implementation of resize-conv I used. It removes the checkerboard artifacts during early training. You may find it useful.
>
>     nn.Upsample(scale_factor=2, mode='bilinear'),
>     nn.ReflectionPad2d(1),
>     nn.Conv2d(ngf * mult, int(ngf * mult / 2),
>               kernel_size=3, stride=1, padding=0),
>
> It should replace the ConvTranspose2d in ResnetGenerator.

This method works pretty well, but it can also cause unstable training, so I applied exponential learning-rate decay and a small initial learning rate to make sure it works well enough. I am currently running another training to see what happens.

@AlexTS1980

@ssnl Thanks for the solution. Did anyone successfully implement it in generators that monotonically upsample the input? In this example, if you drop the ReflectionPad2d and add the padding to the Conv2d instead, the output will be exactly twice the size of the input:

    nn.Upsample(scale_factor=2, mode='bilinear'),
    nn.Conv2d(ngf * mult, int(ngf * mult / 2),
              kernel_size=3, stride=1, padding=1),
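A quick shape check of that variant (a minimal sketch with illustrative ngf=64, mult=4; the ReflectionPad2d is dropped here because the conv's own padding=1 already preserves the upsampled size):

```python
import torch
import torch.nn as nn

ngf, mult = 64, 4  # illustrative values

block = nn.Sequential(
    nn.Upsample(scale_factor=2, mode='bilinear'),
    nn.Conv2d(ngf * mult, int(ngf * mult / 2),
              kernel_size=3, stride=1, padding=1),
)

x = torch.randn(1, ngf * mult, 17, 23)  # odd sizes double exactly too
print(block(x).shape)  # torch.Size([1, 128, 34, 46])
```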

@AlexTS1980

> @AllAwake Cool results!
>
> Here is the implementation of resize-conv I used. It removes the checkerboard artifacts during early training. You may find it useful.
>
>     nn.Upsample(scale_factor=2, mode='bilinear'),
>     nn.ReflectionPad2d(1),
>     nn.Conv2d(ngf * mult, int(ngf * mult / 2),
>               kernel_size=3, stride=1, padding=0),
>
> It should replace the ConvTranspose2d in ResnetGenerator.

I like how it totally gets rid of the checkerboard artefacts, but it comes at the cost of actually learning useful features. I'm trying to balance these effects.


9 participants