
SRGAN doesn't work without pretraining a SRResNet (How to fix) #43

neilthefrobot opened this issue Mar 4, 2021 · 5 comments

@neilthefrobot commented Mar 4, 2021

The GAN losses don't stabilize unless you first pretrain the generator. Without pretraining, the network can still improve image quality, but the improvement comes only from the VGG loss and the GAN part is basically useless (at least in my experience).

The fix is extremely easy. In `get_gan_network`, change the compilation to:

gan = Model(gan_input, x)
gan.compile(loss='mse', optimizer=optimizer)

and in the training loop, comment out the training of the discriminator. This trains the generator to minimize the MSE between the training inputs and the training targets, with no GAN involved. Once this model is trained, revert these changes and continue training (using VGG perceptual loss + GAN loss, and training the discriminator in the training loop).

@HeeebsInc commented Mar 7, 2021

Do you have the code for this? I am currently stuck trying to implement it. It seems to be working, but when I predict with the generator I only get a black image (all zeros), even after denormalizing. I also tried using the MSE loss as you have above, but the loss output is negative, which should not be the case. Your help is greatly appreciated!

@neilthefrobot (Author)
All you need to change is `get_gan_network()` to be:

```python
def get_gan_network(discriminator, shape, generator, optimizer):
    gan_input = Input(shape=shape)
    x = generator(gan_input)
    gan = Model(inputs=gan_input, outputs=x)
    gan.compile(loss='mse', optimizer=optimizer)
    return gan
```

And then remove all of this (the training of the discriminator in the training loop):

```python
rand_nums = np.random.randint(0, x_train_hr.shape[0], size=batch_size)

image_batch_hr = x_train_hr[rand_nums]
image_batch_lr = x_train_lr[rand_nums]
generated_images_sr = generator.predict(image_batch_lr)

real_data_Y = np.ones(batch_size) - np.random.random_sample(batch_size) * 0.2
fake_data_Y = np.random.random_sample(batch_size) * 0.2

discriminator.trainable = True

d_loss_real = discriminator.train_on_batch(image_batch_hr, real_data_Y)
d_loss_fake = discriminator.train_on_batch(generated_images_sr, fake_data_Y)
discriminator_loss = 0.5 * np.add(d_loss_fake, d_loss_real)
```

And set the line that trains the generator to simply map LR to HR, with no GAN loss:

```python
gan_loss = gan.train_on_batch(image_batch_lr, image_batch_hr)
```
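Putting the pieces together, here is a tiny self-contained NumPy sketch of the pretraining phase. The one-layer linear "generator", the shapes, and the data are all toy stand-ins (the real SRGAN generator is a deep residual CNN trained with Keras as above); the point is just that phase one is plain supervised MSE regression from LR inputs to HR targets, with no discriminator anywhere in the loop.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy stand-in for the generator: a single linear map from "LR" vectors to
# "HR" vectors. The real SRGAN generator is a deep residual CNN; this only
# illustrates the MSE pretraining phase, which involves no discriminator.
W = rng.normal(scale=0.1, size=(16, 16))

def mse(pred, target):
    return float(np.mean((pred - target) ** 2))

# Fake paired data: HR targets come from a fixed "true" mapping.
W_true = rng.normal(size=(16, 16))
x_lr = rng.normal(size=(64, 16))
x_hr = x_lr @ W_true

# Plain MSE pretraining loop: gradient descent on the squared error w.r.t. W.
step = 0.01
initial_loss = mse(x_lr @ W, x_hr)
for _ in range(500):
    grad = 2.0 * x_lr.T @ (x_lr @ W - x_hr) / len(x_lr)
    W -= step * grad
final_loss = mse(x_lr @ W, x_hr)  # should be far below initial_loss
```

Once this converges you keep the trained weights and only then bring the discriminator and perceptual loss back in.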

@HeeebsInc

Thank you so much!!! You are a lifesaver.

@siweic0818
> (quotes the `get_gan_network()` fix and training-loop changes above)

Thanks for the code. So you are basically training a non-GAN this way? Once this part is trained, do you save and re-load the weights and then continue training the GAN model?

@neilthefrobot (Author) commented Mar 11, 2021

> Thanks for the code. So you are basically training a non-GAN this way? Once this part is trained, do you save and re-load the weights and then continue training the GAN model?

Yes. Fast.ai calls this "NoGAN training". More info: https://www.fast.ai/2019/05/03/decrappify/

You are basically just training a GAN, but you pretrain the generator so that it is already producing decent images from the start. This usually keeps training more stable.
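The two-phase schedule can be sketched as below. The step functions are hypothetical callables standing in for the real training code (in phase one, something like `gan.train_on_batch` with MSE; in phase two, the discriminator updates plus the perceptual/GAN losses); between phases you would save and reload the generator weights (e.g. `save_weights`/`load_weights` in Keras).

```python
def nogan_schedule(pretrain_step, gan_step, pretrain_iters, gan_iters):
    """Phase 1: generator-only MSE pretraining. Phase 2: full GAN training.

    Each step function performs one training iteration and returns its loss;
    the collected loss histories are returned for monitoring.
    """
    history = {"pretrain": [], "gan": []}
    for _ in range(pretrain_iters):
        history["pretrain"].append(pretrain_step())
    for _ in range(gan_iters):
        history["gan"].append(gan_step())
    return history
```

The design point is simply that the GAN phase never starts from a randomly initialized generator, so the discriminator faces plausible images from step one.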
