This project implements a Generative Adversarial Network (GAN) in PyTorch for generating high-quality images. It combines several techniques: WGAN-GP, Self-Attention, a Conditional GAN (cGAN) with multi-attribute control, Adaptive Discriminator Augmentation (ADA), and an Exponential Moving Average (EMA) of the generator weights. The code can be extended to handle various datasets.
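As a rough illustration of the WGAN-GP component, the gradient penalty can be sketched as below. This is a minimal sketch and not necessarily how this repository implements it: the function name `gradient_penalty` is hypothetical, and the critic is assumed to take only an image batch (a conditional critic would also need the labels passed through).

```python
import torch
import torch.nn as nn


def gradient_penalty(critic: nn.Module, real: torch.Tensor, fake: torch.Tensor) -> torch.Tensor:
    """WGAN-GP penalty: batch mean of (||grad critic(x_hat)||_2 - 1)^2."""
    batch_size = real.size(0)
    # One interpolation coefficient per sample, broadcast over the remaining dims.
    eps = torch.rand(batch_size, *([1] * (real.dim() - 1)), device=real.device)
    x_hat = (eps * real + (1.0 - eps) * fake).detach().requires_grad_(True)
    scores = critic(x_hat)
    (grads,) = torch.autograd.grad(
        outputs=scores,
        inputs=x_hat,
        grad_outputs=torch.ones_like(scores),
        create_graph=True,  # keep the graph so the penalty itself can be backpropagated
    )
    grad_norm = grads.reshape(batch_size, -1).norm(2, dim=1)
    return ((grad_norm - 1.0) ** 2).mean()
```

In a training step the returned value is usually scaled by a penalty weight and added to the critic loss; the weight to use here is a hyperparameter of this project and is not assumed above.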
Make sure you have the following libraries installed in your environment:
pip install torch torchvision tqdm numpy pillow tensorboard
Clone this repository to your local machine:
git clone https://github.com/yourusername/gan-project.git
cd gan-project
Create the necessary directories:
mkdir checkpoints generated_images
Start training:
python main.py
You can modify the hyperparameters in main.py.
After training, you can generate new images by loading the trained EMA generator model and feeding random noise vectors. The generated images will be saved to the generated_images/ folder.
from model import GeneratorWithSpectralNorm
import torch
from torchvision.utils import save_image

# Fall back to CPU when no GPU is available
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Load the generator model
generator = GeneratorWithSpectralNorm(latent_dim=100, num_classes=10).to(device)
generator.load_state_dict(torch.load('checkpoints/generator_epoch_199.pth', map_location=device))
generator.eval()

# Generate images
fixed_noise = torch.randn(64, 100, 1, 1, device=device)
# Labels for the cGAN, assumed here to be integer class indices;
# adjust to whatever label format your generator expects
labels = torch.randint(0, 10, (64,), device=device)
with torch.no_grad():
    fake_images = generator(fixed_noise, labels)
save_image(fake_images, "generated_images/final_output.png", normalize=True)
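An EMA generator is usually kept as a separate copy of the generator whose weights track a running average of the trained weights. The sketch below shows one common way to do that. The helper names `make_ema_copy` and `update_ema` and the default `decay=0.999` are assumptions for illustration, not taken from this repository.

```python
import copy

import torch
import torch.nn as nn


def make_ema_copy(model: nn.Module) -> nn.Module:
    """Frozen deep copy of `model` used as the starting point for the average."""
    ema_model = copy.deepcopy(model).eval()
    for p in ema_model.parameters():
        p.requires_grad_(False)
    return ema_model


@torch.no_grad()
def update_ema(ema_model: nn.Module, model: nn.Module, decay: float = 0.999) -> None:
    """In place: ema = decay * ema + (1 - decay) * current, parameter by parameter."""
    for ema_p, p in zip(ema_model.parameters(), model.parameters()):
        ema_p.mul_(decay).add_(p.detach(), alpha=1.0 - decay)
    # Buffers (e.g. BatchNorm running statistics) are copied rather than averaged.
    for ema_b, b in zip(ema_model.buffers(), model.buffers()):
        ema_b.copy_(b)
```

`update_ema` would typically be called once after each generator optimizer step, and the EMA copy is the one saved and used for sampling.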
To resume training from a checkpoint or to use a pre-trained model, load the saved checkpoint like this:
generator.load_state_dict(torch.load('checkpoints/generator_epoch_199.pth'))
discriminator.load_state_dict(torch.load('checkpoints/discriminator_epoch_199.pth'))
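The two calls above can be wrapped in a small save/load pair. The following is a minimal sketch that assumes the `generator_epoch_N.pth` / `discriminator_epoch_N.pth` naming seen in the file paths above. The helper names `save_checkpoint` and `load_checkpoint` are hypothetical, and whether main.py also stores optimizer state is not known here, so it is left out.

```python
import os

import torch
import torch.nn as nn


def save_checkpoint(generator: nn.Module, discriminator: nn.Module,
                    epoch: int, directory: str = "checkpoints") -> None:
    """Write both state dicts using the generator_epoch_N.pth naming scheme."""
    os.makedirs(directory, exist_ok=True)
    torch.save(generator.state_dict(),
               os.path.join(directory, f"generator_epoch_{epoch}.pth"))
    torch.save(discriminator.state_dict(),
               os.path.join(directory, f"discriminator_epoch_{epoch}.pth"))


def load_checkpoint(generator: nn.Module, discriminator: nn.Module,
                    epoch: int, directory: str = "checkpoints",
                    device: str = "cpu") -> int:
    """Load both state dicts in place and return the epoch to resume from."""
    g_path = os.path.join(directory, f"generator_epoch_{epoch}.pth")
    d_path = os.path.join(directory, f"discriminator_epoch_{epoch}.pth")
    # map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine.
    generator.load_state_dict(torch.load(g_path, map_location=device))
    discriminator.load_state_dict(torch.load(d_path, map_location=device))
    return epoch + 1
```

Both models must already be constructed with the same architecture and arguments that were used when the checkpoint was written.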
Generated images are saved in the generated_images/ directory. You can visualize them directly to assess the quality of the results.
You can evaluate the model using standard GAN evaluation metrics like Inception Score (IS) or Fréchet Inception Distance (FID). These metrics help evaluate the diversity and quality of generated images compared to real images.
pip install pytorch-fid
python -m pytorch_fid data/celeba generated_images/  # For FID
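For reference, FID fits a Gaussian to the Inception features of each image set and measures the distance between the two fits. With $\mu_r, \Sigma_r$ the feature mean and covariance of the real images and $\mu_g, \Sigma_g$ those of the generated images (symbols introduced here for illustration), the standard definition is:

```latex
\mathrm{FID} = \lVert \mu_r - \mu_g \rVert_2^2
  + \operatorname{Tr}\!\left( \Sigma_r + \Sigma_g - 2\,(\Sigma_r \Sigma_g)^{1/2} \right)
```

Lower values indicate that the generated images are statistically closer to the real ones.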