The source code of the project, as well as a detailed description in Ukrainian and instructions for running, training, and testing, can be found in my GitHub repository: https://github.com/andersenbel/rd9_mGAN.
To develop and fine-tune GANs for image restoration tasks, specifically for:
Evaluate the model performance using the following metrics: PSNR (peak signal-to-noise ratio) and SSIM (structural similarity index).
We selected CIFAR-10 for the following reasons:
- It is available directly through `torchvision`, unlike CelebA, which requires additional tools for downloading.

The generator takes low-resolution input images (32x32) and restores them to high resolution (128x128).
- Transposed convolutions (`ConvTranspose2d`) for upscaling.
- Batch normalization (`BatchNorm2d`) for training stability.
- `ReLU` for hidden layers.
- `tanh` for the output layer.

The discriminator classifies whether an image is real or generated.
- Convolutional layers (`Conv2d`) for feature extraction.
- `LeakyReLU` for hidden layers.
- `sigmoid` for classification.

Discriminator:
- Binary Cross-Entropy Loss (`BCELoss`).
- `Adam` with a learning rate of `1e-4`.

Generator:
- Mean Squared Error (MSE) for image restoration.
- `Adam` with a learning rate of `1e-4`.

The models were trained with the following parameters:
generator_epoch_{number}.pth
discriminator_epoch_{number}.pth
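The loss and optimizer settings above, together with the checkpoint names, could be wired into a single training step roughly as follows. This is a minimal sketch and not the repository's `train.py`: the two small networks are stand-ins so the snippet runs on its own, `train_step` and `save_checkpoints` are hypothetical helper names, the `checkpoints` default directory is taken from the evaluation command, and the generator is assumed to be trained with MSE only (the text does not say whether an adversarial term is added).

```python
import os

import torch
import torch.nn as nn

# Stand-in networks (assumption): just enough to map 32x32 -> 128x128
# and to score a 128x128 image, so the step below can be executed.
generator = nn.Sequential(
    nn.Upsample(scale_factor=4, mode="nearest"),
    nn.Conv2d(3, 3, kernel_size=3, padding=1),
    nn.Tanh(),
)
discriminator = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=4, stride=4),
    nn.LeakyReLU(0.2),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(8, 1),
    nn.Sigmoid(),
)

# Losses and optimizers as described in the text.
bce = nn.BCELoss()
mse = nn.MSELoss()
opt_d = torch.optim.Adam(discriminator.parameters(), lr=1e-4)
opt_g = torch.optim.Adam(generator.parameters(), lr=1e-4)


def train_step(low_res, high_res):
    """Run one discriminator update and one generator update."""
    batch = high_res.size(0)
    real_labels = torch.ones(batch, 1)
    fake_labels = torch.zeros(batch, 1)

    fake = generator(low_res)

    # Discriminator: real images should score 1, generated images 0.
    # detach() keeps this update from touching the generator.
    d_loss = bce(discriminator(high_res), real_labels) + bce(
        discriminator(fake.detach()), fake_labels
    )
    opt_d.zero_grad()
    d_loss.backward()
    opt_d.step()

    # Generator: pixel-wise MSE against the high-resolution target.
    g_loss = mse(fake, high_res)
    opt_g.zero_grad()
    g_loss.backward()
    opt_g.step()

    return d_loss.item(), g_loss.item()


def save_checkpoints(epoch, out_dir="checkpoints"):
    """Save both models using the per-epoch file names listed above."""
    os.makedirs(out_dir, exist_ok=True)
    torch.save(generator.state_dict(),
               os.path.join(out_dir, f"generator_epoch_{epoch}.pth"))
    torch.save(discriminator.state_dict(),
               os.path.join(out_dir, f"discriminator_epoch_{epoch}.pth"))
```

Calling `train_step` once per batch and `save_checkpoints(epoch)` at the end of each epoch would produce one pair of `.pth` files per epoch.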
train.py          # Code for training GAN
evaluate.py       # Code for evaluating models (GAN and SRGAN)
generator.py      # Generator architecture
discriminator.py  # Discriminator architecture
srgan.py          # SRGAN generator architecture (optional)
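The contents of `generator.py` and `discriminator.py` are not reproduced here. As an illustration of the architecture described earlier, the following PyTorch sketch uses only the listed building blocks. The layer counts, channel widths, kernel sizes, and the `LeakyReLU` slope of 0.2 are assumptions chosen so that a 32x32 input becomes a 128x128 output; they are not taken from the repository.

```python
import torch
import torch.nn as nn


class Generator(nn.Module):
    """Upscales 3x32x32 images to 3x128x128 (two 2x transposed convolutions)."""

    def __init__(self, channels=3, features=64):
        super().__init__()
        self.net = nn.Sequential(
            # 32x32 -> 64x64
            nn.ConvTranspose2d(channels, features, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(features),
            nn.ReLU(inplace=True),
            # 64x64 -> 128x128
            nn.ConvTranspose2d(features, features // 2, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(features // 2),
            nn.ReLU(inplace=True),
            # Keep 128x128, map back to image channels in [-1, 1]
            nn.ConvTranspose2d(features // 2, channels, kernel_size=3, stride=1, padding=1),
            nn.Tanh(),
        )

    def forward(self, x):
        return self.net(x)


class Discriminator(nn.Module):
    """Maps a 3x128x128 image to a single real/generated probability."""

    def __init__(self, channels=3, features=64):
        super().__init__()
        self.net = nn.Sequential(
            # 128x128 -> 64x64
            nn.Conv2d(channels, features, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            # 64x64 -> 32x32
            nn.Conv2d(features, features * 2, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            # 32x32 -> 16x16
            nn.Conv2d(features * 2, features * 4, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            # 16x16 -> 1x1 score
            nn.Conv2d(features * 4, 1, kernel_size=16, stride=1, padding=0),
            nn.Flatten(),
            nn.Sigmoid(),
        )

    def forward(self, x):
        return self.net(x)
```

With these settings a batch of shape `(N, 3, 32, 32)` passed through `Generator` has shape `(N, 3, 128, 128)`, and `Discriminator` returns one probability per image with shape `(N, 1)`.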
| Epoch | PSNR (dB) | SSIM |
|---|---|---|
| 1 | 17.4899 | 0.7337 |
| 2 | 14.7227 | 0.8234 |
| 3 | 20.4272 | 0.8914 |
| 4 | 22.0976 | 0.9126 |
| 5 | 18.8058 | 0.7790 |
| 6 | 20.8179 | 0.9182 |
| 7 | 23.3236 | 0.9427 |
| 8 | 22.7589 | 0.9459 |
| 9 | 23.6007 | 0.9342 |
| 10 | 22.3043 | 0.9427 |
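For readers unfamiliar with the PSNR column, the metric is defined as 10·log10(MAX² / MSE), where MAX is the largest possible pixel value and MSE is the mean squared error between the reference and restored images. The sketch below computes it with NumPy; the function name `psnr` and the assumption that images are scaled to `[0, data_range]` are illustrative, and `evaluate.py` may compute the metric differently (for example through scikit-image, which is in the install list).

```python
import numpy as np


def psnr(reference, restored, data_range=1.0):
    """Peak signal-to-noise ratio in dB between two images of equal shape."""
    reference = np.asarray(reference, dtype=np.float64)
    restored = np.asarray(restored, dtype=np.float64)
    mse = np.mean((reference - restored) ** 2)
    if mse == 0:
        # Identical images: PSNR is unbounded.
        return float("inf")
    return float(10.0 * np.log10(data_range ** 2 / mse))
```

A uniform error of 0.1 on images in `[0, 1]` gives an MSE of 0.01 and therefore a PSNR of about 20 dB, which is in the range reported in the table. SSIM is more involved (it compares local means, variances, and covariances over sliding windows), so it is usually taken from a library such as `skimage.metrics.structural_similarity` rather than written by hand.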
GAN:
SRGAN:
Metrics improve with epochs:
Optimal quality:
Visual comparison:
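The results confirm that the model is effective at restoring images from low resolution: over 10 epochs, PSNR rose from 17.49 dB to a peak of 23.60 dB (epoch 9), and SSIM rose from 0.73 to a peak of 0.95 (epoch 8), although neither metric improved monotonically.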
./setup_env.sh
source ./_env/bin/activate
pip install torch torchvision matplotlib gdown scikit-image
python train.py --dataset_path ./data --epochs 50 --batch_size 16
python train.py --dataset_path ./data --epochs 50 --batch_size 16 --resize 32
python evaluate.py --model_dir checkpoints --dataset_path ./data --batch_size 16 --max_images 5