Welcome to the Super-Resolution Generative Adversarial Network (SRGAN) repository! This code implements an image-enhancement model that uses a GAN to generate high-resolution images from low-resolution inputs. SRGANs are particularly useful for tasks such as upscaling images with improved visual quality or reducing image noise.
The training process is encapsulated in the train function within the SRGAN.ipynb notebook. Here's a brief breakdown:
- High-resolution (HR) and low-resolution (LR) images are loaded with a PyTorch DataLoader (see the data-loading sketch after this list).
- The number of images in the HR and LR folders is verified for consistency.
- The SRGAN consists of a generator (G) responsible for upscaling LR images and a discriminator (D) distinguishing between real HR images and fake HR images generated by G.
- Model weights are initialized with Kaiming He initialization, and Adam optimizers are set up for both the generator and the discriminator.
- The training loop runs for the number of epochs provided as a hyperparameter.
- In each epoch, the discriminator is updated first: its loss combines the classification error on real HR images with the error on images generated by G.
- After the discriminator update, the generator loss, which consists of VGG (perceptual) loss, adversarial loss, and pixel loss, is backpropagated to update the generator parameters (see the training-step sketch after this list).
- We also experimented with an n:1 update ratio, where the generator is updated n times for every single discriminator update in each epoch. This helps the generator keep up with the discriminator and avoids scenarios where the discriminator always wins. Due to high GPU RAM requirements we could not run this procedure, and we fell back to single updates, which still yielded acceptable generator performance.
- The discriminator is trained to distinguish between real HR images and fake HR images generated by the generator.
- The generator is trained to minimize the adversarial loss and generate realistic HR images.
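The data pipeline can be sketched roughly as follows. This is a minimal example, assuming HR and LR files share names under data/HR and data/LR; the notebook's actual dataset class and paths may differ:

```python
import os
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

class PairedImageDataset(Dataset):
    """Hypothetical helper that loads matching HR/LR image pairs."""
    def __init__(self, hr_dir, lr_dir):
        self.hr_dir, self.lr_dir = hr_dir, lr_dir
        self.files = sorted(os.listdir(hr_dir))
        # Consistency check: both folders must contain the same number of images.
        assert len(self.files) == len(os.listdir(lr_dir)), "HR/LR image counts differ"
        self.to_tensor = transforms.ToTensor()

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        name = self.files[idx]
        hr = self.to_tensor(Image.open(os.path.join(self.hr_dir, name)).convert("RGB"))
        lr = self.to_tensor(Image.open(os.path.join(self.lr_dir, name)).convert("RGB"))
        return lr, hr

loader = DataLoader(PairedImageDataset("data/HR", "data/LR"),
                    batch_size=16, shuffle=True, num_workers=2)
```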
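A single training iteration might then look roughly like this. This is a minimal sketch, not the notebook's exact code: the loss weights, the assumption that D outputs one sigmoid probability per image, and the `vgg_features` extractor (defined in the loss section below) are all illustrative choices:

```python
import torch
import torch.nn as nn

def init_weights(m):
    # Kaiming He initialization for convolutional layers, as described above.
    if isinstance(m, nn.Conv2d):
        nn.init.kaiming_normal_(m.weight)
        if m.bias is not None:
            nn.init.zeros_(m.bias)

bce = nn.BCELoss()  # adversarial criterion; assumes D ends in a sigmoid
mse = nn.MSELoss()  # used for both the pixel loss and the VGG-feature loss

def train_step(G, D, opt_G, opt_D, lr_imgs, hr_imgs, vgg_features, device):
    lr_imgs, hr_imgs = lr_imgs.to(device), hr_imgs.to(device)
    real = torch.ones(hr_imgs.size(0), 1, device=device)
    fake = torch.zeros(hr_imgs.size(0), 1, device=device)

    # Discriminator update: error on real images + error on generated images.
    opt_D.zero_grad()
    fake_hr = G(lr_imgs)
    err_D = bce(D(hr_imgs), real) + bce(D(fake_hr.detach()), fake)
    err_D.backward()
    opt_D.step()

    # Generator update: VGG (perceptual) + adversarial + pixel loss.
    opt_G.zero_grad()
    vgg_loss = mse(vgg_features(fake_hr), vgg_features(hr_imgs))
    adv_loss = bce(D(fake_hr), real)  # reward fooling the discriminator
    pixel_loss = mse(fake_hr, hr_imgs)
    err_G = vgg_loss + 1e-3 * adv_loss + pixel_loss  # weights are assumed
    err_G.backward()
    opt_G.step()
    return err_D.item(), err_G.item()
```

Setup would apply `init_weights` to both networks (`G.apply(init_weights)`, `D.apply(init_weights)`) and create one Adam optimizer per network. The n:1 scheme mentioned above would simply wrap the generator portion in a loop executed n times per discriminator update.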
Three loss components contribute to the overall generator loss (err_G):
- VGG (perceptual) loss: measures the difference between the generated image and the ground truth in a perceptually meaningful way.
- Adversarial loss: captures the discriminator's ability to distinguish between real and fake images.
- Pixel loss: penalizes raw per-pixel differences between the generated image and the ground-truth HR image.
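The perceptual term is typically computed on feature maps from a pretrained VGG network. Here is a minimal sketch of the `vgg_features` extractor assumed in the training-step example above; the notebook's exact VGG variant and cut-off layer may differ:

```python
from torchvision.models import vgg19

# Pretrained VGG19 truncated at an intermediate convolutional block and
# frozen, so gradients flow only into the generator.
# (Newer torchvision versions take weights="IMAGENET1K_V1" instead of pretrained=True.)
vgg_features = vgg19(pretrained=True).features[:36].eval()
for p in vgg_features.parameters():
    p.requires_grad = False
```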
- Training progress is visualized by printing statistics for each batch, including discriminator and generator losses.
- Model checkpoints are saved periodically, enabling training to be resumed or a pre-trained model to be deployed.
- The training loop iterates over multiple epochs, refining the model's performance. Memory-management techniques, such as deleting intermediate tensors to free GPU memory, keep usage efficient during training (see the sketch after this list).
- We tried to be as efficient as possible within the 16 GB of GPU RAM available on Google Colab Pro, but a GAN with this many layers and parameters in the generator and discriminator still demands substantial GPU memory.
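Inside the epoch loop, checkpointing and memory cleanup might look roughly like this. The file paths and the five-epoch save interval are assumptions, and `epoch`, `G`, `D`, and the intermediate tensors come from the surrounding loop:

```python
import torch

# Save checkpoints periodically so training can resume or the model can be deployed.
if epoch % 5 == 0:
    torch.save(G.state_dict(), "model/G.pt")
    torch.save(D.state_dict(), "model/D.pt")

# Drop references to intermediate tensors and release cached GPU memory.
del fake_hr, err_D, err_G
torch.cuda.empty_cache()
```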
For deploying the trained model, follow these steps:
- Load Trained Model:
- Load the trained generator model using the SRGenerator class by uploading the D.pt and G.pt files to the model folder in Google Colab.
- Use the load_state_dict method to load the saved model parameters.
- Size of our stored models: D.pt is 60 MB and G.pt is 27 MB.
- Inference:
- Provide a low-resolution image as input to the generator to obtain a high-resolution output. Note that careful hyperparameter tuning and powerful compute hardware are required for excellent high-resolution image quality.
- Visualize and Save:
- Visualize the enhanced image and save it to a desired location.
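Putting the deployment steps together, the flow might look roughly like this. `SRGenerator` comes from the notebook; its constructor arguments and the file paths here are assumptions:

```python
import torch
from PIL import Image
from torchvision import transforms
from torchvision.utils import save_image

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the trained generator weights (G.pt uploaded to the model/ folder).
G = SRGenerator().to(device)  # SRGenerator is defined in SRGAN.ipynb
G.load_state_dict(torch.load("model/G.pt", map_location=device))
G.eval()

# Inference: feed a low-resolution image through the generator.
lr = transforms.ToTensor()(Image.open("input_lr.png").convert("RGB"))
with torch.no_grad():
    sr = G(lr.unsqueeze(0).to(device)).clamp(0, 1)

# Visualize/save the enhanced image to the desired location.
save_image(sr, "output_sr.png")
```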
- Clone the Repository:
  - `git clone https://github.com/your-username/SRGAN.git`
- Install Dependencies:
  - `pip install -r requirements.txt`
- Prepare Data:
- Organize HR and LR images in the data/ directory.
- Run Training:
- Execute the SRGAN.ipynb notebook to train the model.
- GPU support is recommended for faster training; the code automatically detects and uses CUDA if available.
- Hyperparameters and network architecture can be experimented with for potential improvements.
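Device detection follows the standard PyTorch pattern:

```python
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
```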