This project implements a Generative Adversarial Network (GAN) for image generation. It combines several established techniques: the Wasserstein loss with gradient penalty (WGAN-GP), self-attention layers, a conditional GAN (cGAN) with multi-attribute control, Adaptive Discriminator Augmentation (ADA), and an Exponential Moving Average (EMA) of the generator weights. The project is built with PyTorch and can be adapted to other datasets.
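The first of these techniques, the WGAN-GP gradient penalty, can be sketched as follows. This is a generic implementation of the penalty term from the WGAN-GP formulation, not code taken from this repository. The function name `gradient_penalty` and the default weight `lambda_gp=10.0` are assumptions.

```python
import torch
from torch import nn


def gradient_penalty(critic: nn.Module, real: torch.Tensor, fake: torch.Tensor,
                     lambda_gp: float = 10.0) -> torch.Tensor:
    """WGAN-GP term: lambda_gp * mean((||grad_x critic(x_hat)||_2 - 1)^2).

    x_hat is a random interpolation between real and fake samples.
    """
    batch_size = real.size(0)
    # One mixing coefficient per sample, broadcast over the remaining dims
    alpha = torch.rand(batch_size, *([1] * (real.dim() - 1)), device=real.device)
    # Detach inputs so x_hat is a fresh leaf we can differentiate with respect to
    x_hat = (alpha * real.detach() + (1 - alpha) * fake.detach()).requires_grad_(True)
    scores = critic(x_hat)
    grads, = torch.autograd.grad(
        outputs=scores,
        inputs=x_hat,
        grad_outputs=torch.ones_like(scores),
        create_graph=True,  # keep the graph so the penalty can be backpropagated
    )
    grad_norm = grads.reshape(batch_size, -1).norm(2, dim=1)
    return lambda_gp * ((grad_norm - 1) ** 2).mean()
```

In a critic update step, the loss would typically be `critic(fake).mean() - critic(real).mean() + gradient_penalty(critic, real, fake)`.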
Make sure you have the following libraries installed in your environment:
```bash
pip install torch torchvision tqdm numpy pillow tensorboard
```
Clone this repository to your local machine:
```bash
git clone https://github.com/yourusername/gan-project.git
cd gan-project
```
Create the necessary directories:
```bash
mkdir checkpoints generated_images
```
To start training, run:

```bash
python main.py
```
The training hyperparameters are defined in main.py; edit them there before starting a run.
After training, you can generate new images by loading the trained EMA generator and feeding it random noise vectors together with the class labels used for conditioning. The generated images are saved to the generated_images/ folder.
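```python
import torch
from torchvision.utils import save_image

from model import GeneratorWithSpectralNorm

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Load the trained generator weights
generator = GeneratorWithSpectralNorm(latent_dim=100, num_classes=10).to(device)
generator.load_state_dict(
    torch.load('checkpoints/generator_epoch_199.pth', map_location=device))
generator.eval()

# Sample latent vectors and the conditioning labels for the cGAN
num_images = 64
fixed_noise = torch.randn(num_images, 100, 1, 1, device=device)
# Assumes integer class indices; adjust if the model expects multi-attribute vectors
labels = torch.randint(0, 10, (num_images,), device=device)

with torch.no_grad():
    fake_images = generator(fixed_noise, labels)

save_image(fake_images, 'generated_images/final_output.png', normalize=True)
```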
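The EMA generator referred to above is a copy of the generator whose weights follow a running average of the training weights, which usually gives smoother samples than the raw weights. A minimal sketch of that update, assuming a decay of 0.999 and the hypothetical helper names `make_ema` and `update_ema` (the project's own implementation may differ):

```python
import copy

import torch
from torch import nn


def make_ema(model: nn.Module) -> nn.Module:
    """Create a frozen copy of `model` to hold the averaged weights."""
    ema = copy.deepcopy(model).eval()
    for p in ema.parameters():
        p.requires_grad_(False)
    return ema


@torch.no_grad()
def update_ema(ema_model: nn.Module, model: nn.Module, decay: float = 0.999) -> None:
    """In place: ema_param <- decay * ema_param + (1 - decay) * param."""
    for ema_p, p in zip(ema_model.parameters(), model.parameters()):
        ema_p.lerp_(p, 1.0 - decay)
    # Buffers (e.g. normalization statistics) are copied rather than averaged
    for ema_b, b in zip(ema_model.buffers(), model.buffers()):
        ema_b.copy_(b)
```

Call `update_ema(ema_generator, generator)` after every generator optimizer step, and use the EMA copy, not the training generator, for sampling.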
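To resume training from a checkpoint or to use a pre-trained model, construct the generator and discriminator and load the saved weights into them. These files contain only the model weights, not the optimizer state: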
```python
generator.load_state_dict(torch.load('checkpoints/generator_epoch_199.pth'))
discriminator.load_state_dict(torch.load('checkpoints/discriminator_epoch_199.pth'))
```
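Because the per-model files loaded above hold only weights, resuming training exactly also requires the optimizer states and the epoch counter. One way to bundle them into a single file is sketched below; the dictionary keys and the helper names `save_checkpoint` and `load_checkpoint` are hypothetical and do not match the project's existing checkpoint files.

```python
import torch
from torch import nn, optim


def save_checkpoint(path: str, epoch: int, generator: nn.Module, discriminator: nn.Module,
                    opt_g: optim.Optimizer, opt_d: optim.Optimizer) -> None:
    """Save models, optimizers, and the last finished epoch in one file."""
    torch.save({
        'epoch': epoch,
        'generator': generator.state_dict(),
        'discriminator': discriminator.state_dict(),
        'opt_g': opt_g.state_dict(),
        'opt_d': opt_d.state_dict(),
    }, path)


def load_checkpoint(path: str, generator: nn.Module, discriminator: nn.Module,
                    opt_g: optim.Optimizer, opt_d: optim.Optimizer,
                    device: str = 'cpu') -> int:
    """Restore everything saved by save_checkpoint; return the epoch to resume from."""
    ckpt = torch.load(path, map_location=device)
    generator.load_state_dict(ckpt['generator'])
    discriminator.load_state_dict(ckpt['discriminator'])
    opt_g.load_state_dict(ckpt['opt_g'])
    opt_d.load_state_dict(ckpt['opt_d'])
    return ckpt['epoch'] + 1
```

Keeping the optimizer state matters for Adam-style optimizers, whose moment estimates would otherwise restart from zero.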
Generated images are saved in the generated_images/ directory. You can visualize them directly to assess the quality of the results.
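You can evaluate the model with standard GAN metrics such as the Inception Score (IS) or the Fréchet Inception Distance (FID). IS rates how confidently and how diversely an Inception-v3 classifier labels the generated images (higher is better), while FID compares the statistics of Inception-v3 features extracted from generated and real images (lower is better).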
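To compute FID, install pytorch-fid and compare the real images in data/celeba with the generated images:

```bash
pip install pytorch-fid
python -m pytorch_fid data/celeba generated_images/
```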
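pytorch-fid extracts Inception-v3 pool features from both image folders and then computes FID = ‖μ_r − μ_g‖² + Tr(Σ_r + Σ_g − 2(Σ_r Σ_g)^{1/2}), where μ and Σ are the mean and covariance of the real (r) and generated (g) features. The sketch below implements only that final distance on precomputed feature arrays; the function name `frechet_distance` is an assumption, and any arrays passed in stand in for real Inception features. It uses SciPy, which pytorch-fid also depends on.

```python
import numpy as np
from scipy import linalg


def frechet_distance(feats_real: np.ndarray, feats_fake: np.ndarray) -> float:
    """Fréchet distance between two feature sets, each of shape (N, D)."""
    mu_r, mu_f = feats_real.mean(axis=0), feats_fake.mean(axis=0)
    sigma_r = np.cov(feats_real, rowvar=False)
    sigma_f = np.cov(feats_fake, rowvar=False)
    # Matrix square root of the covariance product; numerical error can
    # introduce tiny imaginary parts, so keep only the real component
    covmean = np.real(linalg.sqrtm(sigma_r @ sigma_f))
    diff = mu_r - mu_f
    return float(diff @ diff + np.trace(sigma_r) + np.trace(sigma_f)
                 - 2.0 * np.trace(covmean))
```

As a sanity check, two identical feature sets give a distance close to 0, and shifting every one of the D feature dimensions by 1 adds D through the mean term.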