🎯 Building an Object Detection Model for MNIST with PyTorch (COCO-style Dataset)

Object detection is one of the core tasks in computer vision, and in this post, we'll walk through a complete pipeline: from downloading and converting the MNIST dataset to COCO format, to building and training a custom object detector in PyTorch. We'll also visualize predictions, plot training history, and analyze failure cases!

This project demonstrates how to:

  • Convert a dataset to the COCO format.

  • Build a lightweight CNN-based object detector.

  • Train and validate the model with bounding box and class predictions.

  • Evaluate model performance visually and statistically.


🔧 1. Step-by-Step: Convert MNIST to COCO Format

We’ll start by downloading MNIST and converting each digit into a COCO-style annotation: each image will have one bounding box and one label.
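
For orientation, each split produces a single JSON file with three top-level keys. A minimal entry looks like this (values illustrative; bbox is COCO-style [x, y, width, height] in pixels, and the real categories list covers all ten digits):

{
  "images": [
    {"id": 0, "file_name": "train_000000.png", "width": 28, "height": 28}
  ],
  "annotations": [
    {"id": 1, "image_id": 0, "category_id": 5, "bbox": [5, 7, 18, 15], "area": 270, "iscrowd": 0}
  ],
  "categories": [
    {"id": 5, "name": "5", "supercategory": "digit"}
  ]
}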

🗂 mnist_download.py

import torch
import torchvision
from torchvision import datasets
import os
import json
from PIL import Image
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
import random

def check_cuda_availability():
    print(f"PyTorch Version: {torch.__version__}")
   
    # Check if CUDA is available
    cuda_available = torch.cuda.is_available()
    print(f"\nCUDA Available: {cuda_available}")
   
    if cuda_available:
        # Get CUDA version
        print(f"CUDA Version: {torch.version.cuda}")
       
        # Get cuDNN version if available
        if hasattr(torch.backends, 'cudnn'):
            print(f"cuDNN Version: {torch.backends.cudnn.version()}")
            print(f"cuDNN Enabled: {torch.backends.cudnn.enabled}")
       
        # Get the current CUDA device
        current_device = torch.cuda.current_device()
        print(f"Current CUDA Device: {current_device}")
       
        # Get the name of the current CUDA device
        device_name = torch.cuda.get_device_name(current_device)
        print(f"CUDA Device Name: {device_name}")
       
        # Get the number of CUDA devices
        device_count = torch.cuda.device_count()
        print(f"Number of CUDA Devices: {device_count}")
       
        # Get CUDA capability
        cuda_capability = torch.cuda.get_device_capability(current_device)
        print(f"CUDA Capability: {cuda_capability}")
       
        # Get maximum memory allocated
        max_memory = torch.cuda.max_memory_allocated(current_device)
        print(f"Maximum Memory Allocated: {max_memory / 1024**2:.2f} MB")
    else:
        print("CUDA is not available. PyTorch will run on CPU only.")

def download_mnist_coco(root_dir='mnist_coco', split='train'):
    """
    Download MNIST dataset and convert it to COCO format
    Args:
        root_dir: Directory to save the dataset
        split: 'train' or 'test'
    Returns:
        None (saves dataset to disk)
    """
    # Create directories
    os.makedirs(os.path.join(root_dir, 'images'), exist_ok=True)
    os.makedirs(os.path.join(root_dir, 'annotations'), exist_ok=True)
   
    # Download MNIST dataset
    mnist_dataset = datasets.MNIST(
        root='./mnist_raw',
        train=(split == 'train'),
        download=True
    )
   
    # Initialize COCO format dictionary
    coco_format = {
        "images": [],
        "annotations": [],
        "categories": [
            {"id": i, "name": str(i), "supercategory": "digit"}
            for i in range(10)
        ]
    }
   
    # Convert MNIST to COCO format
    annotation_id = 1
   
    print(f"Converting {split} split to COCO format...")
    for idx in tqdm(range(len(mnist_dataset))):
        # Get image and label
        image, label = mnist_dataset[idx]
       
        # Create image filename
        image_filename = f"{split}_{idx:06d}.png"
        image_path = os.path.join(root_dir, 'images', image_filename)
       
        # Save image
        image.save(image_path)
       
        # Get image dimensions
        width, height = image.size
       
        # Add image info to COCO format
        image_info = {
            "id": idx,
            "file_name": image_filename,
            "width": width,
            "height": height,
        }
        coco_format["images"].append(image_info)
       
        # Convert image to numpy array to get bounding box
        image_array = np.array(image)
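        # Tight box around the digit: first/last rows and columns that contain any nonzero pixel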
        rows = np.any(image_array, axis=1)
        cols = np.any(image_array, axis=0)
        ymin, ymax = np.where(rows)[0][[0, -1]]
        xmin, xmax = np.where(cols)[0][[0, -1]]
       
        # Add annotation info to COCO format
        annotation_info = {
            "id": annotation_id,
            "image_id": idx,
            "category_id": int(label),
            "bbox": [int(xmin), int(ymin), int(xmax - xmin), int(ymax - ymin)],
            "area": int((xmax - xmin) * (ymax - ymin)),
            "iscrowd": 0
        }
        coco_format["annotations"].append(annotation_info)
        annotation_id += 1
   
    # Save annotations to JSON file
    annotation_file = os.path.join(root_dir, 'annotations', f'instances_{split}.json')
    with open(annotation_file, 'w') as f:
        json.dump(coco_format, f)
   
    print(f"Dataset saved in COCO format at: {root_dir}")
    print(f"Total images: {len(coco_format['images'])}")
    print(f"Total annotations: {len(coco_format['annotations'])}")

def display_mnist_coco(root_dir='mnist_coco', split='train', num_images=4):
    """
    Display random images from the MNIST COCO dataset with their bounding boxes
    Args:
        root_dir: Directory containing the dataset
        split: 'train' or 'test'
        num_images: Number of images to display
    """
    # Load annotations
    annotation_file = os.path.join(root_dir, 'annotations', f'instances_{split}.json')
    with open(annotation_file, 'r') as f:
        coco_data = json.load(f)
   
    # Create image id to annotations mapping
    image_to_anns = {}
    for ann in coco_data['annotations']:
        image_id = ann['image_id']
        if image_id not in image_to_anns:
            image_to_anns[image_id] = []
        image_to_anns[image_id].append(ann)
   
    # Select random images
    selected_images = random.sample(coco_data['images'], min(num_images, len(coco_data['images'])))
   
    # Create a subplot grid sized to the number of selected images
    num_cols = 2
    num_rows = (len(selected_images) + num_cols - 1) // num_cols
    fig, axes = plt.subplots(num_rows, num_cols, figsize=(5 * num_cols, 5 * num_rows))
    axes = axes.ravel()
   
    for idx, img_info in enumerate(selected_images):
        if idx >= num_images:
            break
           
        # Load image
        image_path = os.path.join(root_dir, 'images', img_info['file_name'])
        image = Image.open(image_path)
       
        # Display image
        axes[idx].imshow(image, cmap='gray')
       
        # Get annotations for this image
        annotations = image_to_anns.get(img_info['id'], [])
       
        # Draw bounding boxes and labels
        for ann in annotations:
            bbox = ann['bbox']  # [x, y, width, height]
            category_id = ann['category_id']
           
            # Create rectangle patch
            rect = plt.Rectangle((bbox[0], bbox[1]), bbox[2], bbox[3],
                               fill=False, edgecolor='red', linewidth=2)
            axes[idx].add_patch(rect)
           
            # Add label
            axes[idx].text(bbox[0], bbox[1]-5, f'Digit: {category_id}',
                         color='red', fontsize=10, bbox=dict(facecolor='white', alpha=0.7))
       
        axes[idx].set_title(f'{split.capitalize()} Image {img_info["id"]}')
        axes[idx].axis('off')
   
    plt.tight_layout()
    plt.show()

if __name__ == "__main__":
    check_cuda_availability()
   
    # Download and convert both splits (comment these out once the dataset exists)
    download_mnist_coco(root_dir='mnist_coco', split='train')
    download_mnist_coco(root_dir='mnist_coco', split='test')
   
    # Display random images from both train and test sets
    print("\nDisplaying training images:")
    display_mnist_coco(split='train')
   
    print("\nDisplaying test images:")
    display_mnist_coco(split='test')

Run python mnist_download.py to create the dataset directory mnist_coco/ containing images/ and annotations/.
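
If you have pycocotools installed (pip install pycocotools), you can optionally sanity-check the generated annotation file with the official COCO API:

from pycocotools.coco import COCO

coco = COCO('mnist_coco/annotations/instances_train.json')
print(f"{len(coco.getImgIds())} images, {len(coco.getAnnIds())} annotations")
print([cat['name'] for cat in coco.loadCats(coco.getCatIds())])  # ['0', '1', ..., '9']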


🧠 2. Designing a Simple Object Detection Model

Now, we’ll build a simple model that predicts both a bounding box and a digit class for each image. The architecture includes (a quick shape sanity check follows the list):

  • Convolutional layers to extract features.

  • A shared fully connected layer.

  • Two heads: one for bounding box regression and one for classification.
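
The 128 * 3 * 3 input size of the shared linear layer comes from the three stride-2 max pools (the padding=1 convolutions preserve spatial size). A quick standalone check, assuming the same pooling as in the model below:

import torch
import torch.nn as nn

pool = nn.MaxPool2d(kernel_size=2, stride=2)
x = torch.randn(1, 128, 28, 28)
for _ in range(3):
    x = pool(x)
print(x.shape)  # torch.Size([1, 128, 3, 3])  -- 28 -> 14 -> 7 -> 3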


💡 ObjectDetection.py

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import os
import json
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import time
from datetime import datetime, timedelta

class MNISTCOCODataset(Dataset):
    def __init__(self, root_dir, split='train', transform=None):
        """
        Args:
            root_dir: Directory with all the images and annotations
            split: 'train' or 'test'
            transform: Optional transform to be applied on images
        """
        self.root_dir = root_dir
        self.split = split
        self.transform = transform
       
        # Load annotations
        ann_file = os.path.join(root_dir, 'annotations', f'instances_{split}.json')
        with open(ann_file, 'r') as f:
            self.coco_data = json.load(f)
           
        # Create image_id to annotation mapping
        self.image_to_ann = {}
        for ann in self.coco_data['annotations']:
            self.image_to_ann[ann['image_id']] = ann
           
        self.images = self.coco_data['images']
       
    def __len__(self):
        return len(self.images)
   
    def __getitem__(self, idx):
        # Load image
        img_info = self.images[idx]
        img_path = os.path.join(self.root_dir, 'images', img_info['file_name'])
        image = Image.open(img_path).convert('L')  # Convert to grayscale
       
        # Get annotation
        ann = self.image_to_ann[img_info['id']]
        bbox = torch.tensor(ann['bbox'], dtype=torch.float32)  # [x, y, width, height]
        label = torch.tensor(ann['category_id'], dtype=torch.long)
       
        # Normalize bbox coordinates
        bbox[0] /= image.width  # x
        bbox[1] /= image.height  # y
        bbox[2] /= image.width  # width
        bbox[3] /= image.height  # height
       
        if self.transform:
            image = self.transform(image)
       
        return image, bbox, label

class SimpleDetector(nn.Module):
    def __init__(self, num_classes=10):
        super(SimpleDetector, self).__init__()
       
        # CNN Feature Extractor
        self.features = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
           
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
           
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )
       
        # Fully connected layers
        self.classifier = nn.Sequential(
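            # 128 channels x 3 x 3 spatial positions remain after three stride-2 pools on a 28x28 input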
            nn.Linear(128 * 3 * 3, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5)
        )
       
        # Separate heads for bbox regression and classification
        self.bbox_head = nn.Linear(512, 4)  # 4 for [x, y, width, height]
        self.class_head = nn.Linear(512, num_classes)
       
    def forward(self, x):
        x = self.features(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
       
        bbox = self.bbox_head(x)
        bbox = torch.sigmoid(bbox)  # Normalize bbox predictions to [0, 1]
       
        cls_score = self.class_head(x)
       
        return bbox, cls_score

def format_time(seconds):
    """Convert seconds to human readable time format"""
    return str(timedelta(seconds=int(seconds)))

def train_model(model, train_loader, val_loader, num_epochs=10, device='cuda', accuracy_threshold=98.0):
    criterion_bbox = nn.MSELoss()
    criterion_cls = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)
   
    model = model.to(device)
    best_loss = float('inf')
    best_accuracy = 0.0
   
    # For storing metrics
    metrics_history = {
        'train_bbox_loss': [], 'train_cls_loss': [], 'train_total_loss': [],
        'val_bbox_loss': [], 'val_cls_loss': [], 'val_total_loss': [],
        'train_accuracy': [], 'val_accuracy': []
    }
   
    total_start_time = time.time()
    total_batches = len(train_loader) * num_epochs
    batches_done = 0
   
    print("\n=== Training Start ===")
    print(f"Start time: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    print(f"Total epochs: {num_epochs}")
    print(f"Batches per epoch: {len(train_loader)}")
    print(f"Total batches: {total_batches}")
    print(f"Early stopping accuracy threshold: {accuracy_threshold}%")
    print("=" * 50)
   
    for epoch in range(num_epochs):
        epoch_start = time.time()
        model.train()
        train_loss = 0.0
        train_bbox_loss = 0.0
        train_cls_loss = 0.0
        correct_preds = 0
        total_samples = 0
       
        # Calculate elapsed and estimated time
        elapsed_time = time.time() - total_start_time
        if batches_done > 0:
            time_per_batch = elapsed_time / batches_done
            eta = time_per_batch * (total_batches - batches_done)
        else:
            eta = 0
           
        print(f"\nEpoch {epoch+1}/{num_epochs}")
        print(f"Elapsed time: {format_time(elapsed_time)}")
        print(f"Estimated time remaining: {format_time(eta)}")
        print(f"Expected completion: {datetime.now() + timedelta(seconds=int(eta))}")
       
        # Training progress bar
        train_pbar = tqdm(train_loader,
                         desc=f'Train [{format_time(elapsed_time)} < {format_time(eta)}]',
                         leave=False, ncols=100)
       
        for batch_idx, (images, bboxes, labels) in enumerate(train_pbar):
            batch_start = time.time()
           
            images = images.to(device)
            bboxes = bboxes.to(device)
            labels = labels.to(device)
           
            optimizer.zero_grad()
           
            pred_bboxes, pred_cls = model(images)
            loss_bbox = criterion_bbox(pred_bboxes, bboxes)
            loss_cls = criterion_cls(pred_cls, labels)
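            # Equal-weight sum of box regression (MSE) and digit classification (cross-entropy) losses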
            loss = loss_bbox + loss_cls
           
            loss.backward()
            optimizer.step()
           
            # Update metrics
            train_loss += loss.item()
            train_bbox_loss += loss_bbox.item()
            train_cls_loss += loss_cls.item()
           
            # Calculate accuracy
            pred_labels = torch.argmax(pred_cls, dim=1)
            correct_preds += (pred_labels == labels).sum().item()
            total_samples += labels.size(0)
           
            # Update batch counter
            batches_done += 1
           
            # Calculate speeds and times
            batch_time = time.time() - batch_start
            images_per_sec = images.shape[0] / batch_time
           
            # Update progress bar with detailed metrics
            train_pbar.set_postfix({
                'loss': f'{loss.item():.4f}',
                'bbox': f'{loss_bbox.item():.4f}',
                'cls': f'{loss_cls.item():.4f}',
                'img/s': f'{images_per_sec:.1f}'
            })
           
            # Print detailed metrics every 50 batches
            if (batch_idx + 1) % 50 == 0:
                elapsed = time.time() - total_start_time
                progress = batches_done / total_batches
                eta = (elapsed / progress) - elapsed if progress > 0 else 0
               
                print(f"\nBatch {batch_idx + 1}/{len(train_loader)} Statistics:")
                print(f"  Time elapsed: {format_time(elapsed)}")
                print(f"  Time remaining: {format_time(eta)}")
                print(f"  Processing speed: {images_per_sec:.1f} images/sec")
                print(f"  Loss: {loss.item():.4f}")
                print(f"  Bbox Loss: {loss_bbox.item():.4f}")
                print(f"  Class Loss: {loss_cls.item():.4f}")
                print(f"  Accuracy: {100 * correct_preds / total_samples:.2f}%")
       
        train_pbar.close()
       
        # Calculate average training metrics
        avg_train_loss = train_loss / len(train_loader)
        avg_train_bbox_loss = train_bbox_loss / len(train_loader)
        avg_train_cls_loss = train_cls_loss / len(train_loader)
        train_accuracy = 100 * correct_preds / total_samples
       
        # Validation
        model.eval()
        val_loss = 0.0
        val_bbox_loss = 0.0
        val_cls_loss = 0.0
        val_correct_preds = 0
        val_total_samples = 0
       
        # Validation progress bar
        val_pbar = tqdm(val_loader,
                       desc=f'Val [Epoch {epoch+1}/{num_epochs}]',
                       leave=False, ncols=100)
       
        val_start = time.time()
        with torch.no_grad():
            for images, bboxes, labels in val_pbar:
                images = images.to(device)
                bboxes = bboxes.to(device)
                labels = labels.to(device)
               
                pred_bboxes, pred_cls = model(images)
                loss_bbox = criterion_bbox(pred_bboxes, bboxes)
                loss_cls = criterion_cls(pred_cls, labels)
                loss = loss_bbox + loss_cls
               
                # Update metrics
                val_loss += loss.item()
                val_bbox_loss += loss_bbox.item()
                val_cls_loss += loss_cls.item()
               
                # Calculate accuracy
                pred_labels = torch.argmax(pred_cls, dim=1)
                val_correct_preds += (pred_labels == labels).sum().item()
                val_total_samples += labels.size(0)
               
                val_pbar.set_postfix({
                    'loss': f'{loss.item():.4f}',
                    'bbox': f'{loss_bbox.item():.4f}',
                    'cls': f'{loss_cls.item():.4f}'
                })
       
        val_pbar.close()
        val_time = time.time() - val_start
       
        # Calculate average validation metrics
        avg_val_loss = val_loss / len(val_loader)
        avg_val_bbox_loss = val_bbox_loss / len(val_loader)
        avg_val_cls_loss = val_cls_loss / len(val_loader)
        val_accuracy = 100 * val_correct_preds / val_total_samples
       
        # Update metrics history
        metrics_history['train_bbox_loss'].append(avg_train_bbox_loss)
        metrics_history['train_cls_loss'].append(avg_train_cls_loss)
        metrics_history['train_total_loss'].append(avg_train_loss)
        metrics_history['val_bbox_loss'].append(avg_val_bbox_loss)
        metrics_history['val_cls_loss'].append(avg_val_cls_loss)
        metrics_history['val_total_loss'].append(avg_val_loss)
        metrics_history['train_accuracy'].append(train_accuracy)
        metrics_history['val_accuracy'].append(val_accuracy)
       
        # Print epoch summary with timing information
        epoch_time = time.time() - epoch_start
        elapsed_total = time.time() - total_start_time
        eta_total = (elapsed_total / (epoch + 1)) * (num_epochs - (epoch + 1))
       
        print(f'\nEpoch {epoch+1}/{num_epochs} Complete:')
        print(f'Time Statistics:')
        print(f'  Epoch time: {format_time(epoch_time)}')
        print(f'  Training time: {format_time(epoch_time - val_time)}')
        print(f'  Validation time: {format_time(val_time)}')
        print(f'  Total elapsed: {format_time(elapsed_total)}')
        print(f'  Estimated remaining: {format_time(eta_total)}')
        print(f'Training Metrics:')
        print(f'  Total Loss: {avg_train_loss:.4f}')
        print(f'  Bbox Loss: {avg_train_bbox_loss:.4f}')
        print(f'  Class Loss: {avg_train_cls_loss:.4f}')
        print(f'  Accuracy: {train_accuracy:.2f}%')
        print(f'Validation Metrics:')
        print(f'  Total Loss: {avg_val_loss:.4f}')
        print(f'  Bbox Loss: {avg_val_bbox_loss:.4f}')
        print(f'  Class Loss: {avg_val_cls_loss:.4f}')
        print(f'  Accuracy: {val_accuracy:.2f}%')
       
        # Early stopping: stop once validation accuracy reaches the threshold
        if val_accuracy >= accuracy_threshold:
            print(f"\n=== Early Stopping ===")
            print(f"Validation accuracy {val_accuracy:.2f}% reached threshold {accuracy_threshold}%")
            print(f"Stopping training at epoch {epoch+1}/{num_epochs}")
           
            # Save final model if it's the best
            if val_accuracy > best_accuracy:
                best_accuracy = val_accuracy
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'best_loss': best_loss,
                    'best_accuracy': best_accuracy,
                    'metrics_history': metrics_history,
                    'training_time': time.time() - total_start_time
                }, 'best_model.pth')
                print(f'Final model saved with accuracy: {best_accuracy:.2f}%')
           
            break
       
        # Update best accuracy
        if val_accuracy > best_accuracy:
            best_accuracy = val_accuracy
       
        # Save checkpoint whenever validation loss improves (accuracy is stored alongside)
        if avg_val_loss < best_loss:
            best_loss = avg_val_loss
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'best_loss': best_loss,
                'best_accuracy': best_accuracy,
                'metrics_history': metrics_history,
                'training_time': time.time() - total_start_time
            }, 'best_model.pth')
            print(f'New best model saved! (Val Loss: {best_loss:.4f}, Val Accuracy: {val_accuracy:.2f}%)')
       
        print('=' * 80)
   
    total_time = time.time() - total_start_time
    print('\n=== Training Complete ===')
    print(f'Total training time: {format_time(total_time)}')
    print(f'End time: {datetime.now().strftime("%Y-%m-%d %H:%M:%S")}')
    print(f'Best validation loss: {best_loss:.4f}')
    print(f'Best validation accuracy: {best_accuracy:.2f}%')
    print('=' * 50)
   
    return metrics_history

def visualize_predictions(model, test_loader, device, num_images=4):
    model.eval()
    num_cols = 2
    num_rows = (num_images + num_cols - 1) // num_cols
    fig, axes = plt.subplots(num_rows, num_cols, figsize=(5 * num_cols, 5 * num_rows))
    axes = axes.ravel()
   
    with torch.no_grad():
        for idx, (images, bboxes, labels) in enumerate(test_loader):
            if idx >= num_images:
                break
               
            images = images.to(device)
            pred_bboxes, pred_cls = model(images)
           
            # Get predictions
            pred_bbox = pred_bboxes[0].cpu().numpy()
            pred_label = torch.argmax(pred_cls[0]).item()
            true_bbox = bboxes[0].numpy()
            true_label = labels[0].item()
           
            # Display image
            img = images[0].cpu().squeeze()
            axes[idx].imshow(img, cmap='gray')
           
            # Draw predicted bbox (red) and true bbox (green)
            h, w = img.shape
            pred_rect = plt.Rectangle((pred_bbox[0]*w, pred_bbox[1]*h),
                                   pred_bbox[2]*w, pred_bbox[3]*h,
                                   fill=False, edgecolor='red', linewidth=2)
            true_rect = plt.Rectangle((true_bbox[0]*w, true_bbox[1]*h),
                                    true_bbox[2]*w, true_bbox[3]*h,
                                    fill=False, edgecolor='green', linewidth=2)
           
            axes[idx].add_patch(pred_rect)
            axes[idx].add_patch(true_rect)
            axes[idx].set_title(f'Pred: {pred_label}, True: {true_label}')
            axes[idx].axis('off')
   
    plt.tight_layout()
    plt.show()

def plot_training_history(metrics_history):
    """Plot training and validation metrics history"""
    plt.figure(figsize=(15, 5))
   
    # Plot losses
    plt.subplot(1, 2, 1)
    plt.plot(metrics_history['train_total_loss'], label='Train Loss')
    plt.plot(metrics_history['val_total_loss'], label='Val Loss')
    plt.plot(metrics_history['train_bbox_loss'], label='Train Bbox Loss')
    plt.plot(metrics_history['train_cls_loss'], label='Train Cls Loss')
    plt.title('Training and Validation Losses')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
   
    # Plot accuracies
    plt.subplot(1, 2, 2)
    plt.plot(metrics_history['train_accuracy'], label='Train Accuracy')
    plt.plot(metrics_history['val_accuracy'], label='Val Accuracy')
    plt.title('Training and Validation Accuracy')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy (%)')
    plt.legend()
   
    plt.tight_layout()
    plt.show()

def load_and_visualize_model(model_path='best_model.pth', data_root='mnist_coco', num_images=4):
    """
    Load a trained model and visualize its predictions
    Args:
        model_path: Path to the saved model checkpoint
        data_root: Directory containing the dataset
        num_images: Number of images to visualize
    """
    # Set device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")
   
    # Create model and load weights
    model = SimpleDetector()
    checkpoint = torch.load(model_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model = model.to(device)
    model.eval()
   
    # Print model information
    print("\n=== Model Information ===")
    print(f"Checkpoint from epoch: {checkpoint['epoch'] + 1}")
    print(f"Best validation loss: {checkpoint['best_loss']:.4f}")
    print(f"Best validation accuracy: {checkpoint['best_accuracy']:.2f}%")
    print(f"Total training time: {format_time(checkpoint['training_time'])}")
   
    # Create test dataset and loader
    transform = transforms.Compose([
        transforms.Resize((28, 28)),
        transforms.ToTensor(),
    ])
   
    # Create test dataset
    test_dataset = MNISTCOCODataset(root_dir=data_root, split='test', transform=transform)
    test_loader = DataLoader(test_dataset, batch_size=1, shuffle=True)
   
    # Visualize predictions
    print("\n=== Visualizing Predictions ===")
    num_cols = 2
    num_rows = (num_images + num_cols - 1) // num_cols
    fig, axes = plt.subplots(num_rows, num_cols, figsize=(6 * num_cols, 6 * num_rows))
    axes = axes.ravel()
   
    with torch.no_grad():
        for idx, (images, bboxes, labels) in enumerate(test_loader):
            if idx >= num_images:
                break
               
            images = images.to(device)
            pred_bboxes, pred_cls = model(images)
           
            # Get predictions
            pred_bbox = pred_bboxes[0].cpu().numpy()
            pred_label = torch.argmax(pred_cls[0]).item()
            true_bbox = bboxes[0].numpy()
            true_label = labels[0].item()
           
            # Display image
            img = images[0].cpu().squeeze()
            axes[idx].imshow(img, cmap='gray')
           
            # Draw predicted bbox (red) and true bbox (green)
            h, w = img.shape
            pred_rect = plt.Rectangle((pred_bbox[0]*w, pred_bbox[1]*h),
                                    pred_bbox[2]*w, pred_bbox[3]*h,
                                    fill=False, edgecolor='red', linewidth=2,
                                    label='Prediction')
            true_rect = plt.Rectangle((true_bbox[0]*w, true_bbox[1]*h),
                                    true_bbox[2]*w, true_bbox[3]*h,
                                    fill=False, edgecolor='green', linewidth=2,
                                    label='Ground Truth')
           
            axes[idx].add_patch(pred_rect)
            axes[idx].add_patch(true_rect)
           
            # Add title with predictions
            axes[idx].set_title(f'Pred: {pred_label} (True: {true_label})\n' +
                              f'{"✓" if pred_label == true_label else "✗"}',
                              color='green' if pred_label == true_label else 'red')
           
            # Add legend
            if idx == 0:
                axes[idx].legend()
           
            axes[idx].axis('off')
   
    plt.suptitle('Model Predictions on Test Set\nRed: Predicted Box, Green: True Box',
                 fontsize=12, y=0.95)
    plt.tight_layout()
    plt.show()
   
    # Calculate and display test metrics
    print("\n=== Test Set Metrics ===")
    correct = 0
    total = 0
    bbox_error = 0
   
    with torch.no_grad():
        for images, bboxes, labels in tqdm(test_loader, desc="Evaluating"):
            images = images.to(device)
            bboxes = bboxes.to(device)
            labels = labels.to(device)
           
            pred_bboxes, pred_cls = model(images)
            pred_labels = torch.argmax(pred_cls, dim=1)
           
            # Classification accuracy
            correct += (pred_labels == labels).sum().item()
            total += labels.size(0)
           
            # Bounding box error (MSE)
            bbox_error += nn.MSELoss()(pred_bboxes, bboxes).item()
   
    test_accuracy = 100 * correct / total
    avg_bbox_error = bbox_error / len(test_loader)
   
    print(f"Test Accuracy: {test_accuracy:.2f}%")
    print(f"Average Bounding Box Error: {avg_bbox_error:.4f}")
   
    return model

def analyze_failure_cases(model_path='best_model.pth', data_root='mnist_coco', num_failures=8):
    """
    Analyze and display failure cases where the model made incorrect predictions
    Args:
        model_path: Path to the saved model checkpoint
        data_root: Directory containing the dataset
        num_failures: Number of failure cases to display
    """
    # Set device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
   
    # Create model and load weights
    model = SimpleDetector()
    checkpoint = torch.load(model_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model = model.to(device)
    model.eval()
   
    # Create test dataset and loader
    transform = transforms.Compose([
        transforms.Resize((28, 28)),
        transforms.ToTensor(),
    ])
   
    test_dataset = MNISTCOCODataset(root_dir=data_root, split='test', transform=transform)
    test_loader = DataLoader(test_dataset, batch_size=1, shuffle=True)
   
    # Collect failure cases
    failure_cases = []
    confusion_matrix = torch.zeros(10, 10)  # 10 classes for MNIST
    bbox_errors = []
   
    print("Analyzing model predictions...")
    with torch.no_grad():
        for images, bboxes, labels in tqdm(test_loader, desc="Finding failures"):
            images = images.to(device)
            bboxes = bboxes.to(device)
            labels = labels.to(device)
           
            pred_bboxes, pred_cls = model(images)
            pred_labels = torch.argmax(pred_cls, dim=1)
           
            # Calculate bbox error
            bbox_error = nn.MSELoss()(pred_bboxes, bboxes).item()
           
            # Update confusion matrix
            confusion_matrix[labels.item(), pred_labels.item()] += 1
           
            # Check if prediction is wrong
            if pred_labels.item() != labels.item() or bbox_error > 0.1:  # threshold for significant bbox error
                failure_cases.append({
                    'image': images[0].cpu(),
                    'true_label': labels.item(),
                    'pred_label': pred_labels.item(),
                    'true_bbox': bboxes[0].cpu().numpy(),
                    'pred_bbox': pred_bboxes[0].cpu().numpy(),
                    'bbox_error': bbox_error
                })
                bbox_errors.append(bbox_error)
   
    # Sort failure cases by bbox error
    failure_cases.sort(key=lambda x: x['bbox_error'], reverse=True)
   
    # Calculate statistics
    total_failures = len(failure_cases)
    total_samples = len(test_loader)
    failure_rate = (total_failures / total_samples) * 100
   
    print("\n=== Failure Analysis ===")
    print(f"Total samples tested: {total_samples}")
    print(f"Number of failures: {total_failures}")
    print(f"Failure rate: {failure_rate:.2f}%")
   
    if total_failures == 0:
        print("No failures found!")
        return
   
    # Display worst failure cases
    num_cols = 4
    num_rows = (min(num_failures, len(failure_cases)) + num_cols - 1) // num_cols
    fig = plt.figure(figsize=(15, 4 * num_rows))
   
    for idx, case in enumerate(failure_cases[:num_failures]):
        plt.subplot(num_rows, num_cols, idx + 1)
       
        # Display image
        img = case['image'].squeeze()
        plt.imshow(img, cmap='gray')
       
        # Draw bounding boxes
        h, w = img.shape
        pred_bbox = case['pred_bbox']
        true_bbox = case['true_bbox']
       
        # Predicted bbox in red
        pred_rect = plt.Rectangle((pred_bbox[0]*w, pred_bbox[1]*h),
                                pred_bbox[2]*w, pred_bbox[3]*h,
                                fill=False, edgecolor='red', linewidth=2,
                                label='Prediction')
        # True bbox in green
        true_rect = plt.Rectangle((true_bbox[0]*w, true_bbox[1]*h),
                                true_bbox[2]*w, true_bbox[3]*h,
                                fill=False, edgecolor='green', linewidth=2,
                                label='Ground Truth')
       
        plt.gca().add_patch(pred_rect)
        plt.gca().add_patch(true_rect)
       
        # Add title with error information
        plt.title(f'True: {case["true_label"]}, Pred: {case["pred_label"]}\n' +
                 f'BBox Error: {case["bbox_error"]:.4f}',
                 color='red')
       
        if idx == 0:
            plt.legend()
       
        plt.axis('off')
   
    plt.suptitle('Worst Failure Cases\nRed: Predicted Box, Green: True Box',
                 fontsize=14, y=1.02)
    plt.tight_layout()
   
    # Plot confusion matrix
    plt.figure(figsize=(10, 8))
    confusion_matrix_normalized = confusion_matrix / confusion_matrix.sum(dim=1, keepdim=True)
    plt.imshow(confusion_matrix_normalized.numpy(), cmap='YlOrRd')
    plt.colorbar()
    plt.title('Confusion Matrix (Normalized)')
    plt.xlabel('Predicted Label')
    plt.ylabel('True Label')
   
    # Add text annotations to the confusion matrix
    for i in range(10):
        for j in range(10):
            plt.text(j, i, f'{confusion_matrix_normalized[i, j]:.2f}',
                    ha='center', va='center')
   
    # Plot bbox error distribution
    if bbox_errors:
        plt.figure(figsize=(10, 5))
        plt.hist(bbox_errors, bins=50, edgecolor='black')
        plt.title('Distribution of Bounding Box Errors in Failure Cases')
        plt.xlabel('Bounding Box Error (MSE)')
        plt.ylabel('Count')
   
    plt.show()
   
    # Print detailed statistics
    print("\nFailure Type Analysis:")
    print(f"Classification errors: {sum(case['true_label'] != case['pred_label'] for case in failure_cases)}")
    print(f"Significant bbox errors (>0.1): {sum(case['bbox_error'] > 0.1 for case in failure_cases)}")
   
    if bbox_errors:
        print(f"\nBounding Box Error Statistics:")
        print(f"Mean error: {np.mean(bbox_errors):.4f}")
        print(f"Median error: {np.median(bbox_errors):.4f}")
        print(f"Max error: {np.max(bbox_errors):.4f}")
        print(f"Min error: {np.min(bbox_errors):.4f}")

def main():
    # Set device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # device = torch.device('cpu')
    print(f"Using device: {device}")
   
    # Data transforms
    transform = transforms.Compose([
        transforms.Resize((28, 28)),
        transforms.ToTensor(),
    ])
   
    # Create datasets
    train_dataset = MNISTCOCODataset(root_dir='mnist_coco', split='train', transform=transform)
    test_dataset = MNISTCOCODataset(root_dir='mnist_coco', split='test', transform=transform)
   
    # Create data loaders
    train_size = int(0.8 * len(train_dataset))
    val_size = len(train_dataset) - train_size
    train_dataset, val_dataset = torch.utils.data.random_split(train_dataset, [train_size, val_size])
   
    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=32)
    test_loader = DataLoader(test_dataset, batch_size=1)
   
    # Create and train model with accuracy threshold
    model = SimpleDetector()
    metrics_history = train_model(model, train_loader, val_loader,
                                num_epochs=10,
                                device=device,
                                accuracy_threshold=98.0)  # Set accuracy threshold to 98%
   
    # Plot training history
    plot_training_history(metrics_history)
   
    # Load best model and visualize predictions
    checkpoint = torch.load('best_model.pth')
    model.load_state_dict(checkpoint['model_state_dict'])
    print(f"\nBest model performance:")
    print(f"Validation Loss: {checkpoint['best_loss']:.4f}")
    print(f"Validation Accuracy: {checkpoint['best_accuracy']:.2f}%")
    visualize_predictions(model, test_loader, device)

if __name__ == "__main__":
    # Either run training or analyze existing model
    if not os.path.exists('best_model.pth'):
        # Run training
        main()
    else:
        # Load and visualize model, then analyze failures
        print("\nVisualizing model predictions...")
        load_and_visualize_model('best_model.pth')
       
        print("\nAnalyzing failure cases...")
        analyze_failure_cases('best_model.pth')

Save this file as ObjectDetection.py. This file includes (a minimal inference example follows the list):

  • MNISTCOCODataset: Custom dataset loader.

  • SimpleDetector: A CNN with two heads (bbox and class).

  • train_model: Training loop with detailed logging and early stopping.

  • visualize_predictions: Shows true and predicted boxes.

  • analyze_failure_cases: Displays worst predictions and computes confusion matrix.

  • main: Runs training, saves the model, plots results.
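
After training, you can also run a single image through the saved checkpoint without the helper functions above. A minimal standalone sketch (the image path is illustrative):

import torch
import torchvision.transforms as transforms
from PIL import Image
from ObjectDetection import SimpleDetector

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = SimpleDetector()
checkpoint = torch.load('best_model.pth', map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
model.to(device).eval()

transform = transforms.Compose([transforms.Resize((28, 28)), transforms.ToTensor()])
image = transform(Image.open('mnist_coco/images/test_000000.png').convert('L'))

with torch.no_grad():
    bbox, logits = model(image.unsqueeze(0).to(device))

# bbox is normalized [x, y, w, h]; scale by the image size to get pixels
print('digit:', logits.argmax(dim=1).item())
print('box:', (bbox.squeeze(0) * 28).tolist())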


▶️ 3. How to Run

Once you've downloaded the dataset and written both scripts:

๐Ÿ” Option A: Train the Model

python ObjectDetection.py

If best_model.pth doesn’t exist, this trains a model from scratch, checkpoints the best weights whenever validation loss improves, and stops early once validation accuracy reaches your threshold (default: 98%).

📊 Option B: Visualize and Analyze

If best_model.pth exists:

python ObjectDetection.py

It will:

  • Load the trained model

  • Visualize predictions vs. ground truth

  • Show failure cases with bounding box errors and a confusion matrix


📈 Example Output

  • Training Graphs (loss and accuracy)

  • Visual Predictions: Red = predicted box, Green = ground truth

  • Failure Cases: Includes misclassified digits or inaccurate bounding boxes

  • Confusion Matrix: A heatmap showing which digits were confused

  • BBox Error Distribution: Histogram of errors (the scripts use MSE; see the IoU sketch after this list for a more standard box metric)
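
The scripts measure box quality with MSE on normalized coordinates. If you want a more standard detection metric, an IoU helper for the same [x, y, width, height] format could look like this (a sketch; not used by the scripts above):

def iou_xywh(a, b):
    """IoU of two boxes in [x, y, width, height] format (any consistent units)."""
    ax1, ay1, ax2, ay2 = a[0], a[1], a[0] + a[2], a[1] + a[3]
    bx1, by1, bx2, by2 = b[0], b[1], b[0] + b[2], b[1] + b[3]
    inter_w = max(0.0, min(ax2, bx2) - max(ax1, bx1))
    inter_h = max(0.0, min(ay2, by2) - max(ay1, by1))
    inter = inter_w * inter_h
    union = a[2] * a[3] + b[2] * b[3] - inter
    return inter / union if union > 0 else 0.0

print(iou_xywh([0.1, 0.1, 0.5, 0.5], [0.2, 0.2, 0.5, 0.5]))  # ~0.47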


🧪 Why This Project is Valuable

This hands-on mini-project is perfect if you:

  • Want to learn object detection basics.

  • Are preparing to work with custom datasets in COCO format.

  • Want to test the full ML pipeline from data preparation to evaluation.

You’ll gain practical knowledge of:

  • COCO JSON annotation structure

  • Custom PyTorch datasets and dataloaders

  • Dual-loss training (regression + classification)

  • Visual debugging of model performance


✅ Requirements

Install the dependencies:

pip install torch torchvision numpy matplotlib tqdm pillow

📦 Directory Structure

.
├── mnist_download.py
├── ObjectDetection.py
├── mnist_raw/ (raw MNIST cache created by torchvision)
├── mnist_coco/
│   ├── images/
│   └── annotations/
└── best_model.pth (generated after training)

๐Ÿ Final Words

This post walks you through a complete object detection pipeline using MNIST digits, giving you a reusable framework for other small detection tasks. You’ve learned how to:

  • Convert a dataset to COCO format.

  • Train a detection model with bounding box and class loss.

  • Visualize results and analyze performance issues.

Use this as a base for expanding into more complex datasets or refining your object detector with new architectures like SSD or YOLO!