This repository hosts the implementation of a Vision Transformer (ViT) model trained on the CIFAR-10 dataset. Vision Transformers have revolutionized computer vision by adapting transformer models, traditionally used in NLP, for image-related tasks.
Access the complete project on GitHub: V-transformer
The goal of this project is to explore the application of Vision Transformers on the CIFAR-10 dataset and evaluate their performance compared to traditional convolutional neural networks (CNNs).
Ensure you have Python 3 and pip installed, then install the required libraries:
pip install -r requirements.txt
The CIFAR-10 dataset is automatically downloaded using torchvision.datasets.
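For reference, the loading step looks roughly like this (a minimal sketch; the exact transforms and batch size used in this project may differ):

```python
import torch
import torchvision
import torchvision.transforms as transforms

# Convert images to tensors; the project's actual normalization may differ.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# download=True fetches CIFAR-10 on first use and caches it under ./data.
train_set = torchvision.datasets.CIFAR10(
    root="./data", train=True, download=True, transform=transform
)
test_set = torchvision.datasets.CIFAR10(
    root="./data", train=False, download=True, transform=transform
)

train_loader = torch.utils.data.DataLoader(train_set, batch_size=128, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=128, shuffle=False)
```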
The Vision Transformer model is built as follows:
```python
import torch
import torch.nn as nn


class VisionTransformer(nn.Module):
    def __init__(self, img_size, patch_size, num_classes, dim, depth, heads, mlp_dim):
        super().__init__()
        self.patch_size = patch_size
        num_patches = (img_size // patch_size) ** 2

        # Patch embedding layer: flattened RGB patch -> model dimension
        self.patch_embed = nn.Linear(patch_size ** 2 * 3, dim)

        # Learnable classification (CLS) token and positional encoding
        # (one position per patch, plus one for the CLS token)
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))

        # Transformer encoder layers
        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=dim, nhead=heads, dim_feedforward=mlp_dim, batch_first=True
            ),
            num_layers=depth,
        )

        # Classification head
        self.mlp_head = nn.Linear(dim, num_classes)

    def forward(self, x, output_attentions=False):
        # Split the image into non-overlapping patches and flatten each one
        b, c, h, w = x.shape
        p = self.patch_size
        x = x.unfold(2, p, p).unfold(3, p, p)              # (b, c, h/p, w/p, p, p)
        x = x.permute(0, 2, 3, 1, 4, 5).reshape(b, -1, c * p * p)

        # Embed the patches, prepend the CLS token, add positional encodings
        x = self.patch_embed(x)
        cls_tokens = self.cls_token.expand(b, -1, -1)
        x = torch.cat([cls_tokens, x], dim=1) + self.pos_embedding

        # Encode, then classify from the CLS token's output
        x = self.transformer(x)
        logits = self.mlp_head(x[:, 0])

        # nn.TransformerEncoder does not expose per-layer attention weights,
        # so the attention slot is always None; retrieving the maps would
        # require forward hooks or a custom encoder layer.
        return logits, None
```
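A quick sanity check on a dummy batch (the hyperparameters below are illustrative assumptions, not necessarily those used for the reported results):

```python
# Illustrative hyperparameters for 32x32 RGB CIFAR-10 inputs.
model = VisionTransformer(
    img_size=32, patch_size=4, num_classes=10,
    dim=256, depth=6, heads=8, mlp_dim=512,
)

dummy = torch.randn(2, 3, 32, 32)   # a batch of two CIFAR-10-sized images
logits, _ = model(dummy)
print(logits.shape)                 # torch.Size([2, 10])
```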
The ViT outperformed baseline CNN models on CIFAR-10 in terms of accuracy, demonstrating the effectiveness of transformer-based architectures for vision tasks.
To train the model, run:
python train.py
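The exact configuration lives in train.py; as a rough sketch of what such a script typically does, building on the model and data loaders above (optimizer, learning rate, epoch count, and checkpoint filename are assumptions, not the project's settings):

```python
import torch
import torch.nn as nn
import torch.optim as optim

device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=3e-4)   # assumed settings

for epoch in range(10):                               # assumed epoch count
    model.train()
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        logits, _ = model(images)
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()
    print(f"epoch {epoch}: last batch loss {loss.item():.4f}")

torch.save(model.state_dict(), "checkpoint.pth")      # hypothetical filename
```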
Evaluate the trained model:
python evaluate.py --checkpoint <path_to_checkpoint>
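Likewise, a minimal sketch of the evaluation pass (the checkpoint path and accuracy metric are assumptions inferred from the command above):

```python
import torch

model.load_state_dict(torch.load("checkpoint.pth", map_location=device))
model.eval()

correct = total = 0
with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        logits, _ = model(images)
        correct += (logits.argmax(dim=1) == labels).sum().item()
        total += labels.size(0)

print(f"test accuracy: {correct / total:.4f}")
```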
Feel free to contribute or raise issues on the GitHub repository.