Image classification is a fundamental problem in computer vision, with applications ranging from agriculture to healthcare. In this project, we leverage transfer learning using the pre-trained VGG16 model to classify fruit images. Transfer learning enables us to utilize the rich feature representations learned on a large dataset (ImageNet) and fine-tune them for our specific dataset with relatively few images.
To build an accurate and robust deep learning model that can classify fruit images into multiple categories by fine-tuning the VGG16 architecture.
Language: Python 3.x
Frameworks: TensorFlow & Keras
Libraries: NumPy, Matplotlib, Scikit-learn, SciPy
Model: VGG16
Dataset: Fruits 360 (organized into train, validation, and test folders)
The Fruits 360 dataset includes hundreds of fruit categories with images in separate folders. We split the dataset into:
Training Set: Used for model learning with data augmentation.
Validation Set: Used to tune hyperparameters and monitor overfitting.
Test Set: Used to evaluate final model performance.
We used ImageDataGenerator with augmentation for the training set and simple rescaling for validation and test sets.
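A minimal sketch of this generator setup is shown below, assuming ImageDataGenerator has already been imported from tensorflow.keras.preprocessing.image (as in the imports listed later). The directory paths ('train_dir', 'val_dir', 'test_dir') and the specific augmentation values are illustrative assumptions, not the exact settings used; the generator names train_generator, val_generator, and test_generator are the ones referenced in the rest of the code.

# Sketch of the data generators. Directory paths and augmentation values are illustrative assumptions.
train_datagen = ImageDataGenerator(
    rescale=1.0 / 255,
    rotation_range=20,
    width_shift_range=0.1,
    height_shift_range=0.1,
    zoom_range=0.2,
    horizontal_flip=True
)
val_test_datagen = ImageDataGenerator(rescale=1.0 / 255)  # rescaling only for validation/test

train_generator = train_datagen.flow_from_directory(
    'train_dir', target_size=(224, 224), batch_size=32, class_mode='categorical')
val_generator = val_test_datagen.flow_from_directory(
    'val_dir', target_size=(224, 224), batch_size=32, class_mode='categorical')
test_generator = val_test_datagen.flow_from_directory(
    'test_dir', target_size=(224, 224), batch_size=32, class_mode='categorical', shuffle=False)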
We built our model in the following steps:
1. Import the required libraries and load the pre-trained VGG16 base model without its top classification layers:

import os
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.applications import VGG16
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Flatten, Dropout, BatchNormalization, GlobalAveragePooling2D
from tensorflow.keras.callbacks import ReduceLROnPlateau, EarlyStopping

# Load the VGG16 convolutional base pre-trained on ImageNet, excluding its top classifier
base_model = VGG16(weights='imagenet', include_top=False, input_shape=(224, 224, 3))
2. Freeze all layers of the base model so their ImageNet weights are not updated during the initial training phase:

# Freeze every layer in the VGG16 base
for layer in base_model.layers:
    layer.trainable = False
3. Add custom classification layers on top of the frozen base:
model = Sequential([
    base_model,
    GlobalAveragePooling2D(),
    Dense(256, activation='relu'),
    BatchNormalization(),
    Dropout(0.3),
    Dense(train_generator.num_classes, activation='softmax')
])
We compiled the model using:
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
Train the model, using callbacks that monitor the validation loss to reduce the learning rate or stop training early and prevent overfitting.
import tensorflow as tf
from tensorflow.keras.mixed_precision import set_global_policy
from tensorflow.keras.callbacks import ReduceLROnPlateau, EarlyStopping

# Reduce the learning rate when the validation loss plateaus
lr_scheduler = ReduceLROnPlateau(monitor='val_loss', factor=0.2, patience=2, min_lr=1e-6, verbose=1)
# Stop training early if the validation loss stops improving, restoring the best weights
early_stopping = EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True)

# Use full-precision floats throughout training
set_global_policy('float32')

# Limit the number of batches per epoch to keep training time manageable
steps_per_epoch = 50
validation_steps = 25

history = model.fit(
    train_generator,
    epochs=5,
    validation_data=val_generator,
    steps_per_epoch=steps_per_epoch,
    validation_steps=validation_steps,
    callbacks=[lr_scheduler, early_stopping]
)
Fine-tune by unfreezing a few layers in the VGG16 base model to allow learning on fruit-specific features.
Note that fine-tuning may take considerably longer on a CPU-only machine.
num_layers = len(base_model.layers)
print(f"The base model has {num_layers} layers.")

# Unfreeze the last five layers of the VGG16 base for fine-tuning
for layer in base_model.layers[-5:]:
    layer.trainable = True

# Keep any BatchNormalization layers frozen so their statistics stay stable
for layer in base_model.layers:
    if isinstance(layer, tf.keras.layers.BatchNormalization):
        layer.trainable = False

# Re-compile with a lower learning rate for stable fine-tuning
model.compile(
    loss='categorical_crossentropy',
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-5),
    metrics=['accuracy']
)

history_fine = model.fit(
    train_generator,
    epochs=5,
    validation_data=val_generator,
    steps_per_epoch=steps_per_epoch,    # Reduced steps per epoch to limit training time
    validation_steps=validation_steps,  # Reduced validation steps
    callbacks=[lr_scheduler, early_stopping]
)
The model was evaluated on the test set:
test_loss, test_accuracy = model.evaluate(test_generator, steps=50)
print(f"Test Accuracy: {test_accuracy:.2f}")
model.evaluate(test_generator): Evaluates the model on the test set and prints accuracy, giving a final measure of model performance.
Plots the training and validation accuracy and loss to understand the model's learning progress.
plt.plot: Plots the accuracy and loss for training and validation over epochs.
Visual comparison shows if the model is overfitting, underfitting, or learning effectively.
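For reference, a minimal sketch of this plotting step, assuming the history object returned by model.fit during initial training and Matplotlib imported as plt:

# Plot training vs. validation accuracy and loss from the History object
acc = history.history['accuracy']
val_acc = history.history['val_accuracy']
loss = history.history['loss']
val_loss = history.history['val_loss']
epochs_range = range(1, len(acc) + 1)

plt.figure(figsize=(12, 4))

plt.subplot(1, 2, 1)
plt.plot(epochs_range, acc, label='Training Accuracy')
plt.plot(epochs_range, val_acc, label='Validation Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend()
plt.title('Training and Validation Accuracy')

plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss, label='Training Loss')
plt.plot(epochs_range, val_loss, label='Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.title('Training and Validation Loss')

plt.tight_layout()
plt.show()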
We ran the model on unseen fruit images and visualized predictions:
from collections import Counter
from tensorflow.keras.preprocessing.image import load_img, img_to_array

actual_count = Counter()
predicted_count = Counter()

def get_class_name_from_index(predicted_index, class_index_mapping):
    """Convert a predicted index back to its class name."""
    for class_name, index in class_index_mapping.items():
        if index == predicted_index:
            return class_name
    return "Unknown"

def visualize_prediction_with_actual(img_path, class_index_mapping):
    # The actual class is the name of the image's parent folder
    class_name = os.path.basename(os.path.dirname(img_path))

    # Load and preprocess the image to match the model's 224x224 input
    img = load_img(img_path, target_size=(224, 224))
    img_array = img_to_array(img) / 255.0
    img_array = np.expand_dims(img_array, axis=0)

    # Predict the class
    prediction = model.predict(img_array)
    predicted_index = np.argmax(prediction, axis=-1)[0]
    predicted_class_name = get_class_name_from_index(predicted_index, class_index_mapping)

    actual_count[class_name] += 1
    predicted_count[predicted_class_name] += 1

    plt.figure(figsize=(2, 2), dpi=100)
    plt.imshow(img)
    plt.title(f"Actual: {class_name}, Predicted: {predicted_class_name}")
    plt.axis('off')
    plt.show()

class_index_mapping = train_generator.class_indices
print("Class Index Mapping:", class_index_mapping)  # Debugging: check the mapping

sample_images = [
    'fruits-360-original-size/fruits-360-original-size/Test/apple_braeburn_1/r0_11.jpg',
    'fruits-360-original-size/fruits-360-original-size/Test/pear_1/r0_103.jpg',
    'fruits-360-original-size/fruits-360-original-size/Test/cucumber_3/r0_103.jpg',
]

for img_path in sample_images:
    visualize_prediction_with_actual(img_path, class_index_mapping)
visualize_prediction_with_actual: Loads an image, preprocesses it, predicts its class, and displays it alongside the actual label.
model.predict(img_array): Uses the trained model to make predictions on unseen images.
We noticed some incorrect predictions due to:
Class Similarity: Apples vs. Red Delicious Apples
Imbalanced Samples: Some classes had significantly fewer training examples
Limited Training: Fine-tuning fewer layers might not capture sufficient class-specific features.
Data Augmentation Impact: Aggressive augmentations may distort key features, reducing accuracy for specific images.
Through this project, we successfully built a fruit classification model using VGG16 and transfer learning. The model demonstrated strong generalization on unseen images with minimal training data. This project highlights the power of transfer learning in efficiently building computer vision solutions.
Dataset: Fruits 360 on Kaggle
Skills Network: Copyright © IBM Corporation