
DeepFillv2 release #62

Open
minmax100 opened this issue Jun 15, 2018 · 52 comments
Labels
good first issue Good for newcomers

Comments

@minmax100

Really nice work and a great idea with DeepFillv2! Any plan or expected date to release the DeepFillv2 code?

@xhh232018

I am quite amazed by the results of DeepFillv2! I would also like to know the expected date of the code release, since I want to understand how it works. Thank you!

@yu45020

yu45020 commented Jun 19, 2018

+1
I tried swapping all convolutions in MobileNetV2 with the gated one and adding another decoder for image inpainting, but the results are not good.

Here is my PyTorch code for the gated convolution. I would like to hear your feedback.

import torch
import torch.nn as nn

# BaseModule simply adds more functions to PyTorch's base module.
class PartialConvBlock(BaseModule):
    # mask is binary: 0 is a masked point, 1 is not

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=False, BN=True, activation=None):
        super(PartialConvBlock, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride,
                              padding, dilation, groups, bias)
        self.mask_conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride,
                                   padding, dilation, groups, bias)

        out = []
        if BN:
            out.append(nn.BatchNorm2d(out_channels))
        if activation:
            out.append(activation)
        if len(out):
            self.out = nn.Sequential(*out)

    def forward(self, args):
        x, mask = args
        feature = self.conv(x)
        mask = self.mask_conv(mask)
        x = feature * torch.sigmoid(mask)
        if hasattr(self, 'out'):
            x = self.out(x)
        return x, mask

@JiahuiYu
Owner

JiahuiYu commented Jun 19, 2018

First, thanks for your interest!
Currently we do NOT have an estimated date for the code release. The only thing we know is that we will NOT release the code this summer. However, implementing DeepFill v2 is relatively easy based on the current code if you understand all components of this repo.

@yu45020 It seems your code is not the gated convolution we proposed. Please check the tech report carefully.

  1. We do not have x, mask = args; the only input is x.
  2. We do not have BN and I am not sure how it affects performance. (In fact, in the earlier development of DeepFill v1 we found that BN slightly hurts color consistency. But if you can fine-tune the model without BN in the end, it should be fine.)
  3. In your code, the default activation seems to be None. We use ELU as the activation.
  4. Gated conv can be implemented as:
x1 = self.conv1(x)
x2 = self.conv2(x)
x = sigmoid(x2) * activation(x1)

Or a slightly faster implementation on GPU:

x = self.conv(x)
x1, x2 = split(x, 2) # split along channels 
x = sigmoid(x2) * activation(x1)
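
For reference, a self-contained PyTorch sketch of this split-based variant might look like the following (the module name, the ELU choice for the feature branch, and the use of torch.chunk are assumptions, not the released code):

import torch
import torch.nn as nn
import torch.nn.functional as F

class GatedConv2d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size,
                 stride=1, padding=0, dilation=1):
        super(GatedConv2d, self).__init__()
        # A single convolution producing 2x channels: one half is the feature,
        # the other half is the gate.
        self.conv = nn.Conv2d(in_channels, 2 * out_channels, kernel_size,
                              stride, padding, dilation)

    def forward(self, x):
        feature, gate = torch.chunk(self.conv(x), 2, dim=1)
        # Sigmoid gate multiplied by the activated feature branch.
        return torch.sigmoid(gate) * F.elu(feature)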

Hope it helps.

@yu45020

yu45020 commented Jun 19, 2018

Great!
My code is used as a drop-in replacement in my current project, so there are BN and activation options. Sorry for the confusion. I did misunderstand your paper regarding the mask, though. Batch norm seems to work great on black-and-white pictures, but it seems to hurt color pictures.

By the way, I re-implemented 'Image Inpainting for Irregular Holes Using Partial Convolutions' before coming to your paper, so I was wondering where I should put the mask. But I was wrong.

@limchaos

Really simple and efficient approach! Great idea!
Hi, have you compared with a vanilla D architecture without spectral norm? In other words, what is the benefit of SN in this work?

@JiahuiYu
Owner

@limchaos Thanks for your interest! You raise a good point: what is the performance difference with respect to different GAN architectures? The roadmap of GANs for inpainting is roughly: single vanilla D arch -> two global/local D archs -> SN-PatchGAN. We did not show comparison results in the main paper because:

  1. A single vanilla D arch has worse performance, which is studied in Globally and Locally Consistent Image Completion.
  2. A GAN without SN has worse performance and is not stable, which is studied in Spectral Normalization for Generative Adversarial Networks.
  3. We do have ablation experiments and we may add them to the appendix in the future.

@limchaos

@JiahuiYu I see, thanks for your fast reply :)

@liouxy

liouxy commented Jul 3, 2018

@JiahuiYu Thanks for your excellent work. I'm trying to reimplement DeepFillv2 based on your code, but the results don't seem good enough. Can I ask you about some details?

  1. In the paper, you mention that "our final objective function for inpainting network is only composed of pixel-wise reconstruction loss and SN-PatchGAN loss with default loss balancing hyper-parameter as 1:1". In the v1 code:
    losses['l1_loss'] = l1_alpha * tf.reduce_mean(tf.abs(local_patch_batch_pos - local_patch_x1)*spatial_discounting_mask(config))
    losses['ae_loss'] = l1_alpha * tf.reduce_mean(tf.abs(batch_pos - x1) * (1.-mask))
    losses['g_loss'] = config.GAN_LOSS_ALPHA * losses['g_loss']
    losses['g_loss'] += config.L1_LOSS_ALPHA * losses['l1_loss']
    losses['g_loss'] += config.AE_LOSS_ALPHA * losses['ae_loss']
    What is the detailed form of the g_loss in DeepFillv2?

  2. Since hinge loss is used to calculate the GAN loss, what kind of activation function is used at the last layer of the discriminator?

  3. Are there any necessary changes in the attention module since the mask is free-form?

Thank you again!

@JiahuiYu
Owner

JiahuiYu commented Jul 3, 2018

Hi, thanks for your interest.

  1. The local patch loss is removed. It is AE loss + GAN loss. The L1 loss and the AE loss are just the same loss applied to different regions of the image (inside the mask or not).
  2. We use the setting of this repo: https://github.com/pfnet-research/sngan_projection/blob/master/updater.py (see the hinge-loss sketch after this list).
  3. No. The contextual attention is already implemented to support masks of any shape.
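
For reference, the hinge-style losses in that SNGAN setting could be sketched in PyTorch roughly as follows (this is a paraphrase of the linked updater under standard hinge-loss assumptions, not code from this repo):

import torch
import torch.nn.functional as F

def d_hinge_loss(d_real, d_fake):
    # Discriminator hinge loss, averaged over all SN-PatchGAN outputs.
    return torch.mean(F.relu(1.0 - d_real)) + torch.mean(F.relu(1.0 + d_fake))

def g_hinge_loss(d_fake):
    # Generator hinge loss: push discriminator outputs on fakes upward.
    return -torch.mean(d_fake)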

@liouxy

liouxy commented Jul 4, 2018

@JiahuiYu Thanks for the reply. So is the AE loss calculated on the whole image with the same weight for foreground and background regions?

@JiahuiYu
Owner

JiahuiYu commented Jul 4, 2018

@liouxy Yes, foreground:background is 1:1. Since the mask is no longer rectangular, it is difficult to compute the discounted weights proposed in the CVPR paper. So the final solution is just a simple pixel-wise loss over input and output, as in the sketch below.
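
Putting the comments above together, the DeepFillv2 generator objective might then be assembled as in this sketch (the function and variable names are assumptions, and the 1:1 balance follows the paper statement quoted later in this thread):

import torch

def inpaint_generator_loss(output, target, d_fake):
    # Simple pixel-wise l1 over the whole image, foreground:background = 1:1.
    l1_loss = torch.mean(torch.abs(output - target))
    # SN-PatchGAN hinge generator loss on the discriminator's patch outputs.
    gan_loss = -torch.mean(d_fake)
    # Default loss balancing hyper-parameter of 1:1.
    return l1_loss + gan_loss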

@xhh232018

@JiahuiYu Thanks for your earlier help. I've read your paper carefully several times and my implementation is almost done. However, some details confuse me, so I need your help.

  1. In your paper, you mention that the input to the discriminator is the predicted image, the mask and the guidance (sketch), and that the output is a 3-D feature map of size h×w×c. For the hinge loss calculation, can I directly put this 3-D feature map into the hinge loss formula?
  2. Spectral normalization is used in your algorithm. However, I ignored it at first. Today I spent some time reading the SNGAN paper and realized that the convolution layers of the discriminator network use spectral normalization. Is that right? (I thought they were vanilla convolution layers two days ago; obviously I was wrong.) Also, leaky ReLU and batch normalization are used in SNGAN. Should I add them to the discriminator network?

@yu45020

yu45020 commented Jul 16, 2018

@xhh232018
I am also re-implementing the paper for my project. For the first question, I think you are right. If you flatten the last layer into a linear layer with one node, you are simply doing a weighted sum. Here is the official SNGAN discriminator's last output.

I fail to find a difference between these two approaches.

For the second one, I think it is not necessary to put batch norm on top of spectral norm. PyTorch example

P.S.: Do you find the GAN hard to train?

@JiahuiYu
Owner

@xhh232018 @yu45020 I recommend this official implementation of SN-GAN.

@xhh232018

@yu45020 I have the same problem with GAN training: the g_loss is quite high. Have you solved it?

@yu45020

yu45020 commented Jul 23, 2018

@xhh232018
Not yet. I follow the structure in DeepFill and swap all convolutions with gated convolutions. The GAN part is similar to the paper's. The loss stays high.

I find the L1 loss decreases quickly but the hinge loss decreases slowly.

Here is what I have tried that doesn't work:

  • changing the convolutions (5x5) in the GAN to separable convolutions (like EffNet) in order to increase the batch size
  • adding batch norm in the generator
  • concatenating decoder features with the corresponding encoder features

I have not trained the model for very long since I rent GPU time. I am wondering whether using pre-trained networks as the discriminator would help.

But if I replace the partial convolution with gated convolution and use the complicated loss function from this paper, the improvement is visible within a few iterations. I might be misunderstanding some details in the GAN.

@xhh232018

@yu45020 An expert told me that it is common for GAN training to take quite a long time. Also, the author mentioned that there is no batch-norm operation in the network. I trained DeepFill v1 for 5 days to get good results. I tried PConv 2 weeks ago; it converges very quickly because its structure is relatively simple. I have only trained DeepFill v2 for 2 days, so I will see what happens after several more days of training.

@yu45020

yu45020 commented Jul 23, 2018

Thanks for the info. I noticed that DeepFill v2 removes batch norm because it can hurt color consistency, as the author mentioned above. But my application focuses on black-and-white images, and batch norm helps there.

@shihaobai

@JiahuiYu What is the meaning of the input x? Does it refer to the incomplete image? Is it the same as what is defined in the v1 code, i.e. x = tf.concat([x, ones_x, ones_x*mask], axis=3)?

@JiahuiYu
Owner

@shihaobai You are correct. x means the input incomplete image, which is exactly x = tf.concat([x, ones_x, ones_x*mask], axis=3).

@shihaobai

shihaobai commented Aug 24, 2018

@JiahuiYu Thanks for your reply, it does help. I also have another question. I tried to use ReLU as the activation for the gated conv, but I found that my d_loss converged so quickly that my generator couldn't learn well. When I used ELU instead, the model worked better. So I think there must be something wrong with my discriminator or with how often I train the generator versus the discriminator. How often do you train the discriminator and the generator? Here is my discriminator code:
def dis_conv(x, cnum, ksize=5, stride=2, name='conv', training=True):
    x_shape = x.get_shape().as_list()
    w = tf.get_variable(name=name + '_w', shape=[ksize, ksize, x_shape[-1]] + [cnum])
    w = spectral_norm(w, name=name)
    x = tf.nn.conv2d(x, w, strides=[1, stride, stride, 1], padding='SAME')
    bias = tf.get_variable(name=name + '_bias', shape=[cnum])
    return tf.nn.leaky_relu(x + bias)


def build_SNGAN_discriminator(x, mask, batch_size, reuse=False, training=True):
    with tf.variable_scope('discriminator', reuse=reuse):
        cnum = 64
        x = dis_conv(x, cnum, name='conv1', training=training)
        x = dis_conv(x, 2*cnum, name='conv2', training=training)
        x = dis_conv(x, 4*cnum, name='conv3', training=training)
        x = dis_conv(x, 4*cnum, name='conv4', training=training)
        x = dis_conv(x, 4*cnum, name='conv5', training=training)
        return x

@JiahuiYu
Owner

@shihaobai I have six discriminator convolutions in build_SNGAN_discriminator.
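
To make this concrete, a six-layer SN-PatchGAN discriminator could be sketched in PyTorch as below; the channel progression simply extends the five-layer snippet above by one more 4*cnum layer, and the 5x5 kernel, stride 2 and 0.2 leaky-ReLU slope are assumptions rather than confirmed details:

import torch
import torch.nn as nn
import torch.nn.functional as F

class SNPatchDiscriminator(nn.Module):
    def __init__(self, in_channels, cnum=64):
        super(SNPatchDiscriminator, self).__init__()
        channels = [cnum, 2 * cnum, 4 * cnum, 4 * cnum, 4 * cnum, 4 * cnum]
        layers = []
        prev = in_channels
        for c in channels:
            # Spectral normalization on every convolution.
            layers.append(nn.utils.spectral_norm(
                nn.Conv2d(prev, c, kernel_size=5, stride=2, padding=2)))
            prev = c
        self.convs = nn.ModuleList(layers)

    def forward(self, x):
        for conv in self.convs:
            # Leaky ReLU after every layer, including the last one.
            x = F.leaky_relu(conv(x), negative_slope=0.2)
        return x  # h x w x c patch outputs fed to the hinge loss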

@nogu-atsu

nogu-atsu commented Aug 27, 2018

I implemented DeepFillv2 and am now training it, but the quality of the generated results isn't good.

How many iterations did you need to get the high-quality results in the paper?
I'm training the model for 200,000 iterations with minibatch size 12. Is that enough?

@annzheng

annzheng commented Sep 9, 2018

@JiahuiYu Hi, thanks for your code and paper :)
I'm trying to train the DeepFillv2 model now. Since there are so many parameters, training is very slow, but I don't know how to slim the model width to make training more efficient. Should I reduce the number of conv layers or the number of channels? Thanks in advance!

@JiahuiYu
Owner

JiahuiYu commented Sep 9, 2018

@annzheng You can give it a try. I train DeepFillv2 on a Tesla V100 GPU with 16 GB VRAM.

@JiahuiYu
Owner

JiahuiYu commented Sep 9, 2018

@nogu-atsu FYI, I use mini-batch size 24. I train it on a Tesla V100 GPU (16 GB VRAM) for five days, with GPU utilization at almost 100%.

@xiaosean

@JiahuiYu Hi, thanks for your code and paper.
I have a question about the DeepFillv2 discriminator.
Are all layers in the discriminator applied like the following,

x = leaky_relu(spectral_norm(x))

including the first and last layer outputs?

@aiueogawa

@JiahuiYu
Thank you for your response!
I also think a discriminator with gated convolutions would work better, since gated convolutions could pay attention to the region around the mask and the user inputs. I'll try it later.

I have three more questions.

  1. What learning rate and optimizer did you use?
  2. Regarding the following statement, how did you do this for free-form masks?

For efficiency, we can also restrict the search range of contextual attention module from the whole image to a local neighborhood, which can make the run-time significantly faster while maintaining overall quality of the results.

  3. How did you incorporate user inputs in the implementation? Did you put them into the mask?

@JiahuiYu
Owner

@aiueogawa Hi, here is some information:

  1. Same optimizer and learning rate as the released DeepFill v1.
  2. One can obtain one or more rectangular masks which cover the free-form mask. This trick is mainly used in cases where the images are extremely large and the masks are relatively small.
  3. We use an additional channel for the user sketch instead of putting it into the mask (a sketch of this input layout follows below).
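
As an illustration of that input layout, a minimal sketch (the channel ordering, the all-ones channel carried over from the v1 concatenation, and the function name are assumptions):

import torch

def build_generator_input(image, mask, sketch):
    # image: (N, 3, H, W) incomplete image; mask: (N, 1, H, W) with 1 = hole;
    # sketch: (N, 1, H, W) user-guidance channel kept separate from the mask.
    ones = torch.ones_like(mask)
    return torch.cat([image, ones, ones * mask, sketch], dim=1)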

@theahura

Hey Jiahui,

Thanks for being so responsive on this thread about DeepFillv2.
Quick question: you mention in the paper that your model has 4.1M parameters. Is that just the generator, or does that include the discriminator too?

@JiahuiYu
Owner

@theahura After training, we only use the generator for inpainting, so the 4.1M parameters are only the generator.

@theahura

A quick follow-up. In the paper you describe using a coarse network and a refinement network from your previous work. Is the coarse network trained separately / does it have its own loss as in the previous work? My own implementations of DeepFillv2 have trouble learning anything useful in the coarse network.

Thanks again for your help and quick answers!

@JiahuiYu
Owner

@theahura The coarse network exactly follows the training in this code, where we use a pixel-wise loss as additional supervision. The two networks are trained jointly, not separately.

@theahura

theahura commented Sep 27, 2018

So in DeepFillv2, both the final output of the overall network and the output of the coarse network get a pixel-wise loss?

@JiahuiYu
Owner

@theahura Yes.

@aiueogawa

aiueogawa commented Oct 8, 2018

@JiahuiYu
I'm a bit confused about the activation function at the last layer of the discriminator in DeepFillv2.

You have said two different things about it:

  1. You used the same setting as the official SN-GANs, where no activation (identity activation) is used, in DeepFillv2 release #62 (comment).
  2. You used leaky ReLU in DeepFillv2 release #62 (comment).

I wonder which is right.

Intuitively, the SN-PatchGAN loss wants the discriminator output to be -1 for fake images generated by the generator and +1 for real images, so the activation at the last layer should be symmetric over its domain.
This idea leads me to no activation (identity activation).
A tanh function seems a good alternative: it is not only symmetric but also 1-Lipschitz and lies in (-1, 1).
Considering E_{x~P_data}[ReLU(1 - D(x))] + E_{z~P_z}[ReLU(1 + D(G(z)))], I think D(x) should lie in (-1, 1).

What activation function did you use at the last layer of the discriminator in DeepFillv2?
And what do you think about this?

@JiahuiYu
Owner

JiahuiYu commented Oct 8, 2018

@aiueogawa Hi, thanks for your interest and good question.

Both of my answers in #62 are correct. We use the loss function of SN-GANs as stated in #62 (comment) (please do not assume the layers and activation functions are the same), and we do use leaky ReLU for all convolutional layers in the discriminator as stated in #62 (comment) (all layers, including the last layer).

What you are asking about is the activation function of the last layer. You are correct that, intuitively, the activation at the last layer should be symmetric over its domain. However, I believe that as long as the activation function covers the domain (-1, 1) and the 1-Lipschitz condition holds, it is fine. Practically we find leaky ReLU (piecewise linear) works well. I guess using another activation function in the last layer would not affect performance much. Feel free to report your results here if you give it a try.

@aiueogawa

@JiahuiYu Thank you! I see.
As you say, leaky ReLU also seems good.
In terms of the magnitude of gradients, leaky ReLU seems better than tanh, whose gradients might vanish.

@khemrajrathore

khemrajrathore commented Oct 10, 2018

@JiahuiYu I implemented the algorithm for generating random masks as suggested in the paper. Below are the masks the algorithm generates.
The hyperparameter values are:
imageHeight, imageWidth = 256, 256
maxVertex = 20
maxLength = 200
maxAngle = 210
maxBrushWidth = 30
The results are not satisfactory; what would be good hyperparameter values?
See demo1, demo2, demo3, demo4

@JiahuiYu
Owner

@khemrajrathore The joints between two lines should be very smooth, as shown in our paper. This is achieved by setting the joint width equal to the line width.
The original paper has a typo. Please refer to this algorithm.

P.S.: To keep this issue compact, I have changed your images to links instead of displaying them directly.

@aiueogawa

@JiahuiYu
What hyperparameters (maxVertex, maxLength, maxBrushWidth and maxAngle) did you use for free-form mask generation, especially on CelebA-HQ?
And how many strokes per free-form mask did you use?

@JiahuiYu
Owner

JiahuiYu commented Oct 10, 2018

@aiueogawa @khemrajrathore I am using:

    min_num_vertex = 4
    max_num_vertex = 12
    mean_angle = 2*math.pi / 5
    angle_range = 2*math.pi / 15
    min_width = 12
    max_width = 40
    average_length = math.sqrt(H*H+W*W) / 8
    l = np.clip(np.random.normal(loc=average_length, scale= average_length//2), 0, 2*average_length)
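
Filling in the surrounding steps, a free-form mask generator using these values might look roughly like the following sketch (the use of PIL, the 1-4 stroke count mentioned later in this thread, the random turn direction, and the joint circles with radius equal to half the brush width are a reading of this discussion, not the exact released algorithm):

import math
import numpy as np
from PIL import Image, ImageDraw

def random_free_form_mask(H, W):
    min_num_vertex, max_num_vertex = 4, 12
    mean_angle, angle_range = 2 * math.pi / 5, 2 * math.pi / 15
    min_width, max_width = 12, 40
    average_length = math.sqrt(H * H + W * W) / 8

    mask = Image.new('L', (W, H), 0)
    draw = ImageDraw.Draw(mask)
    for _ in range(np.random.randint(1, 5)):  # 1 to 4 strokes
        num_vertex = np.random.randint(min_num_vertex, max_num_vertex + 1)
        width = int(np.random.uniform(min_width, max_width))
        x, y = np.random.randint(0, W), np.random.randint(0, H)
        angle = np.random.uniform(0, 2 * math.pi)
        for _ in range(num_vertex):
            # Turn by an angle around mean_angle, in a random direction.
            turn = np.random.uniform(mean_angle - angle_range, mean_angle + angle_range)
            angle += turn if np.random.rand() < 0.5 else -turn
            length = np.clip(np.random.normal(loc=average_length, scale=average_length // 2),
                             0, 2 * average_length)
            nx = int(np.clip(x + length * math.cos(angle), 0, W - 1))
            ny = int(np.clip(y + length * math.sin(angle), 0, H - 1))
            draw.line([(x, y), (nx, ny)], fill=255, width=width)
            # Smooth joints: circles with radius equal to half the brush width.
            draw.ellipse([nx - width // 2, ny - width // 2,
                          nx + width // 2, ny + width // 2], fill=255)
            x, y = nx, ny
    return np.asarray(mask, np.float32) / 255.0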

@aiueogawa

@JiahuiYu Thanks for your answer.
min_num_vertex, mean_angle, angle_range and min_width are not in your free-form mask generation algorithm.
I expect them to be used as follows:

  • num_vertex is chosen uniformly between min_num_vertex and max_num_vertex
  • angle is chosen uniformly between mean_angle - angle_range and mean_angle + angle_range
  • width is chosen uniformly between min_width and max_width

However, how is length chosen?
Furthermore, how is num_strokes chosen?

@khemrajrathore

@JiahuiYu @aiueogawa Thanks for the suggestion. With the hyperparameter values discussed above, these are the results:
min_num_vertex = 4
max_num_vertex = 12
mean_angle = 2*math.pi / 5
angle_range = 2*math.pi / 15
min_width = 12
max_width = 40
Demo0 Demo1 Demo2 Demo3

@aiueogawa

@khemrajrathore
Your code still has a bug around drawing the circle at each joint point, e.g. your joints are too big, as @JiahuiYu said.
The radius of the circle should be equal to half the width of the line.

@JiahuiYu
Owner

JiahuiYu commented Oct 11, 2018

@aiueogawa I set it as average_length = math.sqrt(H*H+W*W) / 8. For more detail, I have updated the code above.

@aiueogawa

aiueogawa commented Oct 11, 2018

@JiahuiYu Thanks for revised information.

Q1:
Your SN-PatchGAN discriminator produces h×w×c outputs, and the loss defined below is also of size h×w×c, because ReLU is an element-wise operation and expectations are taken over samples.

[image: SN-PatchGAN loss definition from the paper]

How are the loss values summed up into a single total loss: simple summation, or their average?
This difference (sum or average) affects the relative contributions of the reconstruction loss and the SN-PatchGAN loss to the final generator objective, because you mentioned,

our final objective function for inpainting network is only composed of pixel-wise l1 reconstruction loss and SN-PatchGAN loss with default loss balancing hyper-parameter as 1 : 1.

in the DeepFillv2 paper.

Q2:
You describe the following two options for memory and computational efficiency of the contextual attention layer in DeepFillv1:

  1. extracting background patches with strides to reduce the number of filters
  2. downscaling the resolution of foreground inputs before convolution and upscaling the attention map after propagation

and another option in DeepFillv2:

  3. restricting the search range of the contextual attention module from the whole image to a local neighborhood

In DeepFillv2, what values did you use for the stride and downscale_rate?
And did you restrict the search range in CelebA-HQ training without user-guidance channels?

Q3:
Then, in DeepFillv2, what values did you use as the stride and downscale_rate for contextual attention?

Q4:
How many times do you train the discriminator for each training step of the generator?

Q5:
In DeepFillv2 paper,

The overall mask generation algorithm is illustrated in Algorithm 1. Additionally we can sample multiple strokes in single image to mask multiple regions.

Are multiple strokes used in training?

P.S.

Can you please ask all your questions for one time instead of long mutual conversations?

I already tried to ask all my related questions at one time, but you always answered only part of them, so I had to ask many times.
BTW, when a comment is updated with additional questions, do you notice the update?

@JiahuiYu
Owner

JiahuiYu commented Oct 11, 2018

@aiueogawa I have merged your questions and deleted the redundant ones. Can you please ask all your questions at one time instead of in a long back-and-forth? That way others who read this issue can get a clean view.

Q1: Reduce mean, as shown here.
Q2: Option 3 is not used in the CelebA-HQ case. All results in the paper are without option 3; option 3 is for real user cases where the images are very large.
Q3: The same as in this repo by default.
Q4: I use 1:1 for training the discriminator and generator.
Q5: I use a random number of strokes between 1 and 4.

@JiahuiYu
Owner

@aiueogawa I have opened a specific issue for your case. I do not understand what is confusing to you, so let's communicate in issue #158. I am trying my best to help.

Repository owner locked and limited conversation to collaborators Oct 23, 2018
@JiahuiYu JiahuiYu added the good first issue Good for newcomers label Aug 3, 2019