The objective of this project is to predict captions for images with a K-Nearest Neighbors (KNN) approach, leveraging image and caption embeddings from the CLIP model and Faiss for efficient nearest-neighbor search.
This project implements a K-Nearest Neighbors (KNN) algorithm for image captioning based on the paper "A Distributed Representation Based Query Expansion Approach for Image Captioning". While modern Vision-Language Models (VLMs) are widely used for image captioning, this project explores this earlier KNN-based approach, which still performs surprisingly well.
1. Image and Caption Embeddings: Use pre-extracted embeddings from the CLIP model.
2. Find Nearest Neighbors: For each query image, find its k nearest neighbor images using Faiss.
3. Query Vector Calculation: Compute a weighted sum of the neighbors' caption embeddings, where each caption is weighted by the cosine similarity between its image and the query image (see the formula after this list).
4. Caption Prediction: The predicted caption is the one whose embedding is closest to the query vector; in this implementation it is chosen from the retrieved neighbors' captions.
5. Evaluation: BLEU score is used to compare the predicted captions against the ground-truth captions.
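Concretely, writing $x$ for the query image embedding, $\mathcal{N}_k(x)$ for its $k$ nearest neighbor images, $v_i$ for a neighbor's image embedding, and $c_{i,1},\dots,c_{i,5}$ for that neighbor's caption embeddings (notation introduced here only to summarize the steps above), the query vector and predicted caption computed in the code below are:

$$
q \;=\; \frac{1}{5k} \sum_{i \in \mathcal{N}_k(x)} \sum_{j=1}^{5} \cos(x, v_i)\, c_{i,j},
\qquad
\hat{c} \;=\; \operatorname*{arg\,max}_{i \in \mathcal{N}_k(x),\; j \in \{1,\dots,5\}} \cos(q, c_{i,j}).
$$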
```
# Download pre-extracted CLIP embeddings (Google Drive file IDs)
# !gdown 1RwhwntZGZ9AX8XtGIDAcQD3ByTcUiOoO  # image embeddings
!gdown 18Z7QuFLJFL18xrxKWdqWgLcAgvGJDWJV
# !gdown 1b-4hU2Kp93r1nxMUGEgs1UbZov0OqFfW  # caption embeddings
!gdown 1BBXx3wmO6s4Zjn-E4nIWmcBIpGmyedPM
```
```
# Download MS-COCO val2014 images and annotations, and install Faiss
!wget http://images.cocodataset.org/zips/val2014.zip
!unzip /content/val2014.zip
!wget http://images.cocodataset.org/annotations/annotations_trainval2014.zip
!unzip /content/annotations_trainval2014.zip
!pip install faiss-cpu
```
```python
import torchvision.datasets as dset
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import torch
import torch.nn as nn
import torch.nn.functional as F
from nltk.translate import bleu_score
import faiss
import numpy as np
```
```python
def get_transform():
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),               # convert the PIL Image to a tensor
        transforms.Normalize(
            (0.485, 0.456, 0.406),           # normalize image for pre-trained model
            (0.229, 0.224, 0.225),
        )
    ])
    return transform

coco_dset = dset.CocoCaptions(root='/content/val2014',
                              annFile='/content/annotations/captions_val2014.json',
                              transform=get_transform())
coco_dset1 = dset.CocoCaptions(root='/content/val2014',
                               annFile='/content/annotations/captions_val2014.json')

print('Number of samples: ', len(coco_dset))
img, target = coco_dset[3]  # load 4th sample
print("Image Size: ", img.shape)
print(target)
```
```python
# Collect the ground-truth captions in the dataset's sorted image-id order
ids = list(sorted(coco_dset.coco.imgs.keys()))
captions = []
for i in range(len(ids)):
    anns = coco_dset.coco.loadAnns(coco_dset.coco.getAnnIds(ids[i]))
    captions.append([ele['caption'] for ele in anns][:5])  # 5 captions per image
captions_np = np.array(captions)
print('Captions:', captions_np.shape)
```
```python
captions_flat = captions_np.flatten().tolist()
print('Total captions:', len(captions_flat))
```
```python
cap_path = '/content/coco_captions.npy'
caption_embeddings = np.load(cap_path)
print('Caption embeddings', caption_embeddings.shape)
```
```python
img_path = '/content/coco_imgs.npy'
image_embeddings = np.load(img_path)
print('Image embeddings', image_embeddings.shape)
```
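If the downloaded `.npy` files are unavailable, embeddings of the same shape (roughly `(N, 512)` for images and `(N, 5, 512)` for captions) could be regenerated. The sketch below is an assumption, not the original extraction pipeline: it uses the Hugging Face `transformers` CLIP ViT-B/32 checkpoint, so the resulting vectors may differ from the provided ones.

```python
# Hedged sketch: regenerate CLIP embeddings with Hugging Face transformers (ViT-B/32, 512-d).
# Assumes every image has at least 5 captions (true for COCO val2014).
import numpy as np
import torch
from transformers import CLIPModel, CLIPProcessor

device = "cuda" if torch.cuda.is_available() else "cpu"
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device).eval()
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

img_embs, cap_embs = [], []
with torch.no_grad():
    for i in range(len(coco_dset1)):
        img, caps = coco_dset1[i]                  # PIL image + list of captions
        caps = caps[:5]                            # keep 5 captions per image
        img_inputs = processor(images=img, return_tensors="pt").to(device)
        txt_inputs = processor(text=caps, return_tensors="pt",
                               padding=True, truncation=True).to(device)
        img_embs.append(model.get_image_features(**img_inputs).cpu().numpy()[0])
        cap_embs.append(model.get_text_features(**txt_inputs).cpu().numpy())

image_embeddings = np.stack(img_embs).astype("float32")     # (N, 512)
caption_embeddings = np.stack(cap_embs).astype("float32")   # (N, 5, 512)
```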
```python
def accuracy(predict, real):
    '''
    Use BLEU score as a measurement of accuracy.
    :param predict: a list of predicted captions
    :param real: a list of lists of ground-truth captions
    :return: mean BLEU score
    '''
    accuracy = 0
    for i, pre in enumerate(predict):
        references = real[i]
        score = bleu_score.sentence_bleu(references, pre)
        accuracy += score
    return accuracy / len(predict)
```
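Note that NLTK's `sentence_bleu` expects the hypothesis and each reference to be tokenized (lists of words); passing raw strings makes it compare character by character, which is why `accuracy_v2` below lowercases and splits before scoring. A minimal sketch of the expected call, with made-up captions and an optional smoothing function for short sentences:

```python
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction

# Illustrative example: references and hypothesis are token lists, not strings.
references = [
    "a man riding a wave on a surfboard".split(),
    "a surfer rides a large wave in the ocean".split(),
]
hypothesis = "a man riding a wave on his surfboard".split()

# method1 smoothing avoids zero scores when a higher-order n-gram has no match.
score = sentence_bleu(references, hypothesis,
                      smoothing_function=SmoothingFunction().method1)
print(f"BLEU: {score:.3f}")
```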
```python
import random

def nearest_img_indices(i, k, index_type="IndexFlatL2"):
    '''Return the query image embedding, the indices of its k nearest images, and its captions.'''
    sample_image_emb = image_embeddings[i]
    sample_image_cap = captions_np[i]
    if index_type == "IndexHNSWFlat":
        index = faiss.IndexHNSWFlat(512, 32)    # approximate graph-based search, 32 links per node
    elif index_type == "IndexFlatIP":
        index = faiss.IndexFlatIP(512)          # exact inner-product search
    else:
        index = faiss.IndexFlatL2(512)          # exact Euclidean (L2) search
    index.add(image_embeddings)
    query = sample_image_emb.reshape(1, -1)
    distances, indices = index.search(query, k + 1)  # k+1: the query image is returned as its own neighbor
    indices = indices[:, 1:]                         # drop the query image itself
    indices = indices[0]
    return sample_image_emb, indices, sample_image_cap

def cosine_similarity(u, v):
    dot_product = np.dot(u, v)
    norm_u = np.linalg.norm(u)
    norm_v = np.linalg.norm(v)
    similarity = dot_product / (norm_u * norm_v)
    return similarity

def find_query_vector(indices, k, sample_image_emb):
    '''Similarity-weighted average of the neighbors' caption embeddings.'''
    query_vector = np.zeros(512)
    for i in indices:
        for j in range(5):
            query_vector = query_vector + (cosine_similarity(sample_image_emb, image_embeddings[i])
                                           * caption_embeddings[i][j])
    divisor = k * 5
    new_query_vector = query_vector / divisor
    return new_query_vector

def find_best_caption(query_vector, indices):
    '''Pick the neighbor caption whose embedding is most similar to the query vector.'''
    best_caption = ""
    max_sim = -10
    for i in indices:
        for j in range(5):
            sim = cosine_similarity(query_vector, caption_embeddings[i][j])
            if sim > max_sim:
                max_sim = sim
                best_caption = captions_np[i][j]
    return best_caption

# Sample 1000 random query images (len(coco_dset) - 1 keeps indices in range)
sample_img_indices = [random.randint(0, len(coco_dset) - 1) for _ in range(1000)]

def accuracy_v2(predict, real):
    '''Mean BLEU score with lowercased, whitespace-tokenized captions.'''
    lower_n_split = lambda x: x.lower().split()
    accuracy = 0
    for i, pre in enumerate(predict):
        refs = real[i]
        score = bleu_score.sentence_bleu([lower_n_split(ref) for ref in refs],
                                         lower_n_split(pre))
        accuracy += score
    return accuracy / len(predict)

# sample_img_indices = [0, 1, 2, 3, 4]
```
```python
def find_caption(idx, k):
    '''Predict captions for the sampled images using the Faiss index named by idx and k neighbors.'''
    predicted = []
    real = []
    for val in sample_img_indices:
        sample_image_emb, indices, sample_image_cap = nearest_img_indices(val, k, idx)
        query_vector = find_query_vector(indices, k, sample_image_emb)
        best_caption = find_best_caption(query_vector, indices)
        predicted.append(best_caption)
        real.append(sample_image_cap)
    a = accuracy_v2(predicted, real)
    return a
```
```python
import matplotlib.pyplot as plt

acc = []
K = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
for k in K:
    acc.append(find_caption("IndexFlatL2", k))
print(acc)

plt.plot(K, acc, marker='o')
plt.title('Accuracy vs. k')
plt.xlabel('k')
plt.ylabel('Accuracy')
plt.grid(True)
plt.show()
```
```python
k = 7
acc1 = find_caption("IndexFlatL2", k)
acc2 = find_caption("IndexFlatIP", k)
acc3 = find_caption("IndexHNSWFlat", k)
print("accuracy for IndexFlatL2", acc1)
print("accuracy for IndexFlatIP", acc2)
print("accuracy for IndexHNSWFlat", acc3)
```
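For context on this comparison: `IndexFlatL2` and `IndexFlatIP` are exact brute-force searches under L2 distance and inner product respectively, while `IndexHNSWFlat` builds an approximate graph index that trades a little recall for speed. With L2-normalized vectors, inner product equals cosine similarity, so one common pattern (a sketch on random data, not part of the original notebook) is to normalize before using `IndexFlatIP`:

```python
import numpy as np
import faiss

d = 512
xb = np.random.rand(10_000, d).astype('float32')   # stand-in for the image embeddings
xq = np.random.rand(1, d).astype('float32')        # stand-in for a query embedding

# faiss.normalize_L2 normalizes rows in place, so work on copies.
xb_n, xq_n = xb.copy(), xq.copy()
faiss.normalize_L2(xb_n)
faiss.normalize_L2(xq_n)

index = faiss.IndexFlatIP(d)        # inner product on unit vectors == cosine similarity
index.add(xb_n)
sims, ids = index.search(xq_n, 5)   # top-5 most cosine-similar vectors
print(ids, sims)
```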
```python
import matplotlib.pyplot as plt

n = 5
sample_img_indices = [random.randint(0, len(coco_dset) - 1) for _ in range(n)]

def Qualitative(idx, k):
    '''Show a few query images with their ground-truth and predicted captions.'''
    predicted = []
    real = []
    for val in sample_img_indices:
        sample_image_emb, indices, sample_image_cap = nearest_img_indices(val, k, idx)
        query_vector = find_query_vector(indices, k, sample_image_emb)
        best_caption = find_best_caption(query_vector, indices)
        predicted.append(best_caption)
        real.append(sample_image_cap)
        img, target = coco_dset1[val]   # untransformed PIL image for display
        plt.imshow(img)
        plt.axis('off')
        plt.show()
        print(captions_np[val])         # ground-truth captions
        print(best_caption)             # predicted caption
    a = accuracy_v2(predicted, real)
    return a

Qualitative("IndexFlatL2", 7)
```