🎯 Building an Object Detection Model for MNIST with PyTorch (COCO-style Dataset)
Object detection is one of the core tasks in computer vision, and in this post, we'll walk through a complete pipeline: from downloading and converting the MNIST dataset to COCO format, to building and training a custom object detector in PyTorch. We'll also visualize predictions, plot training history, and analyze failure cases!
This project demonstrates how to:
- Convert a dataset to the COCO format.
- Build a lightweight CNN-based object detector.
- Train and validate the model with bounding box and class predictions.
- Evaluate model performance visually and statistically.
1. Step-by-Step: Convert MNIST to COCO Format
We’ll start by downloading MNIST and converting each digit into a COCO-style annotation: each image will have one bounding box and one label.
File: mnist_download.py
import torchimport torchvisionfrom torchvision import datasetsimport osimport jsonfrom PIL import Imageimport numpy as npfrom tqdm import tqdmimport matplotlib.pyplot as pltimport random
def check_cuda_availability():
    """Print the PyTorch version and, when a GPU is usable, its CUDA details.

    The report covers CUDA/cuDNN versions, the current device index and name,
    the device count, the compute capability, and the peak memory allocated
    so far on the current device. Nothing is returned.
    """
    print(f"PyTorch Version: {torch.__version__}")

    cuda_available = torch.cuda.is_available()
    print(f"\nCUDA Available: {cuda_available}")

    # Guard clause: without CUDA there is nothing further to report.
    if not cuda_available:
        print("CUDA is not available. PyTorch will run on CPU only.")
        return

    print(f"CUDA Version: {torch.version.cuda}")

    # cuDNN is an optional backend, so probe for it before querying it.
    if hasattr(torch.backends, 'cudnn'):
        print(f"cuDNN Version: {torch.backends.cudnn.version()}")
        print(f"cuDNN Enabled: {torch.backends.cudnn.enabled}")

    current_device = torch.cuda.current_device()
    print(f"Current CUDA Device: {current_device}")

    device_name = torch.cuda.get_device_name(current_device)
    print(f"CUDA Device Name: {device_name}")

    device_count = torch.cuda.device_count()
    print(f"Number of CUDA Devices: {device_count}")

    cuda_capability = torch.cuda.get_device_capability(current_device)
    print(f"CUDA Capability: {cuda_capability}")

    # Reported in MiB (bytes / 1024**2).
    max_memory = torch.cuda.max_memory_allocated(current_device)
    print(f"Maximum Memory Allocated: {max_memory / 1024**2:.2f} MB")
def _tight_bbox(image_array):
    """Return the COCO ``[x, y, width, height]`` box around the non-zero pixels.

    ``image_array`` is indexed as (row, column). The first/last non-zero
    row and column are inclusive pixel indices, so the extent is
    ``max - min + 1``. Returns ``None`` when the image has no non-zero pixel.
    """
    rows = np.flatnonzero(np.any(image_array, axis=1))
    cols = np.flatnonzero(np.any(image_array, axis=0))
    if rows.size == 0 or cols.size == 0:
        return None
    ymin, ymax = int(rows[0]), int(rows[-1])
    xmin, xmax = int(cols[0]), int(cols[-1])
    return [xmin, ymin, xmax - xmin + 1, ymax - ymin + 1]


def download_mnist_coco(root_dir='mnist_coco', split='train'):
    """
    Download MNIST dataset and convert it to COCO format

    Each image is written as ``<root_dir>/images/<split>_<index>.png`` and
    receives one annotation: the tight box around its non-zero pixels,
    labelled with the digit class. Annotations are written to
    ``<root_dir>/annotations/instances_<split>.json``.

    Args:
        root_dir: Directory to save the dataset
        split: 'train' selects the MNIST training set; any other value
            selects the test set

    Returns:
        None (saves dataset to disk)
    """
    # Create directories
    os.makedirs(os.path.join(root_dir, 'images'), exist_ok=True)
    os.makedirs(os.path.join(root_dir, 'annotations'), exist_ok=True)

    # Download MNIST dataset
    mnist_dataset = datasets.MNIST(
        root='./mnist_raw',
        train=(split == 'train'),
        download=True
    )

    # Initialize COCO format dictionary
    coco_format = {
        "images": [],
        "annotations": [],
        "categories": [
            {"id": i, "name": str(i), "supercategory": "digit"}
            for i in range(10)
        ]
    }

    annotation_id = 1
    skipped = 0
    print(f"Converting {split} split to COCO format...")
    for idx in tqdm(range(len(mnist_dataset))):
        image, label = mnist_dataset[idx]

        # Compute the box first: an image without any non-zero pixel has no
        # object to annotate, so it is left out of the dataset entirely
        # rather than crashing on an empty index array.
        bbox = _tight_bbox(np.array(image))
        if bbox is None:
            skipped += 1
            continue

        image_filename = f"{split}_{idx:06d}.png"
        image.save(os.path.join(root_dir, 'images', image_filename))

        width, height = image.size
        coco_format["images"].append({
            "id": idx,
            "file_name": image_filename,
            "width": width,
            "height": height,
        })

        coco_format["annotations"].append({
            "id": annotation_id,
            "image_id": idx,
            "category_id": int(label),
            "bbox": bbox,
            "area": bbox[2] * bbox[3],
            "iscrowd": 0
        })
        annotation_id += 1

    # Save annotations to JSON file
    annotation_file = os.path.join(root_dir, 'annotations', f'instances_{split}.json')
    with open(annotation_file, 'w') as f:
        json.dump(coco_format, f)

    print(f"Dataset saved in COCO format at: {root_dir}")
    print(f"Total images: {len(coco_format['images'])}")
    print(f"Total annotations: {len(coco_format['annotations'])}")
    if skipped:
        print(f"Skipped blank images: {skipped}")
def display_mnist_coco(root_dir='mnist_coco', split='train', num_images=4):
    """
    Display random images from the MNIST COCO dataset with their bounding boxes

    The subplot grid is sized from the number of images actually drawn
    (two columns, as many rows as needed), so any ``num_images`` works.

    Args:
        root_dir: Directory containing the dataset
        split: 'train' or 'test'
        num_images: Number of images to display (capped at the dataset size)
    """
    # Load annotations
    annotation_file = os.path.join(root_dir, 'annotations', f'instances_{split}.json')
    with open(annotation_file, 'r') as f:
        coco_data = json.load(f)

    # Create image id to annotations mapping
    image_to_anns = {}
    for ann in coco_data['annotations']:
        image_to_anns.setdefault(ann['image_id'], []).append(ann)

    sample_size = min(num_images, len(coco_data['images']))
    if sample_size <= 0:
        print("No images to display.")
        return
    selected_images = random.sample(coco_data['images'], sample_size)

    # squeeze=False keeps `axes` two-dimensional even for a single row, so
    # ravel() always yields a flat sequence of axes.
    num_cols = 2
    num_rows = (sample_size + num_cols - 1) // num_cols
    _, axes = plt.subplots(num_rows, num_cols,
                           figsize=(10, 5 * num_rows), squeeze=False)
    axes = axes.ravel()

    # Hide every axis up front so unused grid cells stay blank.
    for ax in axes:
        ax.axis('off')

    for ax, img_info in zip(axes, selected_images):
        image_path = os.path.join(root_dir, 'images', img_info['file_name'])
        # Copy the pixels out so the file handle is released immediately.
        with Image.open(image_path) as handle:
            pixels = np.array(handle)
        ax.imshow(pixels, cmap='gray')

        # Draw bounding boxes and labels
        for ann in image_to_anns.get(img_info['id'], []):
            bbox = ann['bbox']  # [x, y, width, height]
            category_id = ann['category_id']

            rect = plt.Rectangle((bbox[0], bbox[1]), bbox[2], bbox[3],
                                 fill=False, edgecolor='red', linewidth=2)
            ax.add_patch(rect)

            ax.text(bbox[0], bbox[1]-5, f'Digit: {category_id}',
                    color='red', fontsize=10,
                    bbox=dict(facecolor='white', alpha=0.7))

        ax.set_title(f'{split.capitalize()} Image {img_info["id"]}')

    plt.tight_layout()
    plt.show()
if __name__ == "__main__":
    check_cuda_availability()

    # Build each split only when its annotation file is missing, so a first
    # run creates the dataset and later runs go straight to visualization.
    for split_name in ('train', 'test'):
        split_annotations = os.path.join(
            'mnist_coco', 'annotations', f'instances_{split_name}.json')
        if not os.path.exists(split_annotations):
            download_mnist_coco(root_dir='mnist_coco', split=split_name)

    # Display random images from both train and test sets
    print("\nDisplaying training images:")
    display_mnist_coco(split='train')
    print("\nDisplaying test images:")
    display_mnist_coco(split='test')
Run `python mnist_download.py` to create the dataset directory `mnist_coco/`, containing `images/` and `annotations/`.
2. Designing a Simple Object Detection Model
Now, we’ll build a simple model that predicts both a bounding box and digit class for each image. The architecture includes:
- Convolutional layers to extract features.
- A shared fully connected layer.
- Two heads: one for bounding box regression and one for classification.
💡 ObjectDetection.py
import torchimport torch.nn as nnimport torch.optim as optimfrom torch.utils.data import Dataset, DataLoaderimport torchvision.transforms as transformsimport osimport jsonfrom PIL import Imageimport numpy as npimport matplotlib.pyplot as pltfrom tqdm import tqdmimport timefrom datetime import datetime, timedelta
class MNISTCOCODataset(Dataset):
    """Single-object detection dataset read from COCO-style annotations.

    Every sample is ``(image, bbox, label)`` where ``bbox`` is
    ``[x, y, width, height]`` divided by the image width/height and ``label``
    is the annotation's ``category_id``.
    """

    def __init__(self, root_dir, split='train', transform=None):
        """
        Args:
            root_dir: Directory holding ``images/`` and ``annotations/``
            split: Selects ``annotations/instances_<split>.json``
            transform: Optional callable applied to the loaded PIL image
        """
        self.root_dir = root_dir
        self.split = split
        self.transform = transform

        ann_file = os.path.join(root_dir, 'annotations', f'instances_{split}.json')
        with open(ann_file, 'r') as handle:
            self.coco_data = json.load(handle)

        # One annotation is kept per image id; if an id repeats, the last
        # annotation in file order wins.
        self.image_to_ann = {
            record['image_id']: record
            for record in self.coco_data['annotations']
        }
        self.images = self.coco_data['images']

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        meta = self.images[idx]
        image = Image.open(
            os.path.join(self.root_dir, 'images', meta['file_name'])
        ).convert('L')  # force single-channel grayscale

        record = self.image_to_ann[meta['id']]
        bbox = torch.tensor(record['bbox'], dtype=torch.float32)  # [x, y, width, height]
        label = torch.tensor(record['category_id'], dtype=torch.long)

        # Scale x/width by the image width and y/height by the image height.
        divisors = (image.width, image.height, image.width, image.height)
        for position, divisor in enumerate(divisors):
            bbox[position] /= divisor

        if self.transform:
            image = self.transform(image)

        return image, bbox, label
class SimpleDetector(nn.Module):
    """Small CNN that predicts one bounding box and one class per image.

    Three conv/ReLU/max-pool stages feed a shared 512-unit fully connected
    layer, followed by separate linear heads for box regression and
    classification. The ``128 * 3 * 3`` flatten size corresponds to a
    28x28 input halved three times by the pooling layers.
    """

    def __init__(self, num_classes=10):
        super().__init__()

        # CNN feature extractor: channel widths 1 -> 32 -> 64 -> 128.
        widths = (1, 32, 64, 128)
        stages = []
        for in_channels, out_channels in zip(widths, widths[1:]):
            stages.append(nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1))
            stages.append(nn.ReLU(inplace=True))
            stages.append(nn.MaxPool2d(kernel_size=2, stride=2))
        self.features = nn.Sequential(*stages)

        # Shared fully connected trunk.
        self.classifier = nn.Sequential(
            nn.Linear(128 * 3 * 3, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
        )

        # Task-specific heads.
        self.bbox_head = nn.Linear(512, 4)  # [x, y, width, height]
        self.class_head = nn.Linear(512, num_classes)

    def forward(self, x):
        """Return ``(bbox, cls_score)``: sigmoid-squashed boxes and raw class logits."""
        shared = self.classifier(torch.flatten(self.features(x), 1))
        bbox = torch.sigmoid(self.bbox_head(shared))  # keep box values in [0, 1]
        cls_score = self.class_head(shared)
        return bbox, cls_score
def format_time(seconds):
    """Render a duration given in seconds as text such as ``1:01:01``.

    The value is truncated to whole seconds first; durations of a day or
    more come out as ``N day(s), H:MM:SS``.
    """
    whole_seconds = int(seconds)
    duration = timedelta(seconds=whole_seconds)
    return str(duration)
def _save_checkpoint(path, epoch, model, optimizer, best_loss, best_accuracy,
                     metrics_history, start_time):
    """Serialize model/optimizer state plus run statistics to ``path``."""
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'best_loss': best_loss,
        'best_accuracy': best_accuracy,
        'metrics_history': metrics_history,
        'training_time': time.time() - start_time
    }, path)


def _run_validation(model, val_loader, criterion_bbox, criterion_cls, device, desc):
    """Run one gradient-free pass over ``val_loader``.

    Returns a dict of summed per-batch losses ('loss', 'bbox', 'cls') and
    classification counts ('correct', 'samples').
    """
    totals = {'loss': 0.0, 'bbox': 0.0, 'cls': 0.0, 'correct': 0, 'samples': 0}
    model.eval()
    val_pbar = tqdm(val_loader, desc=desc, leave=False, ncols=100)
    with torch.no_grad():
        for images, bboxes, labels in val_pbar:
            images = images.to(device)
            bboxes = bboxes.to(device)
            labels = labels.to(device)

            pred_bboxes, pred_cls = model(images)
            loss_bbox = criterion_bbox(pred_bboxes, bboxes)
            loss_cls = criterion_cls(pred_cls, labels)
            loss = loss_bbox + loss_cls

            totals['loss'] += loss.item()
            totals['bbox'] += loss_bbox.item()
            totals['cls'] += loss_cls.item()

            pred_labels = torch.argmax(pred_cls, dim=1)
            totals['correct'] += (pred_labels == labels).sum().item()
            totals['samples'] += labels.size(0)

            val_pbar.set_postfix({
                'loss': f'{loss.item():.4f}',
                'bbox': f'{loss_bbox.item():.4f}',
                'cls': f'{loss_cls.item():.4f}'
            })
    val_pbar.close()
    return totals


def train_model(model, train_loader, val_loader, num_epochs=10, device='cuda',
                accuracy_threshold=98.0, checkpoint_path='best_model.pth'):
    """Train the detector with a combined box-regression + classification loss.

    The total loss is ``MSELoss(bbox) + CrossEntropyLoss(class)``, optimized
    with Adam (lr=0.001). After every epoch the model is validated; a
    checkpoint is written whenever the validation loss improves, and training
    stops early once validation accuracy reaches ``accuracy_threshold``.

    Args:
        model: Module returning ``(bbox, class_logits)`` for a batch of images.
        train_loader: Iterable of ``(images, bboxes, labels)`` training batches.
        val_loader: Iterable of ``(images, bboxes, labels)`` validation batches.
        num_epochs: Maximum number of epochs.
        device: Device the model and batches are moved to.
        accuracy_threshold: Validation accuracy (percent) that triggers early
            stopping.
        checkpoint_path: Where checkpoints are written.

    Returns:
        Dict mapping metric names (train/val bbox, cls and total loss, and
        train/val accuracy) to lists with one value per completed epoch.
    """
    criterion_bbox = nn.MSELoss()
    criterion_cls = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    model = model.to(device)
    best_loss = float('inf')
    best_accuracy = 0.0

    # For storing metrics
    metrics_history = {
        'train_bbox_loss': [],
        'train_cls_loss': [],
        'train_total_loss': [],
        'val_bbox_loss': [],
        'val_cls_loss': [],
        'val_total_loss': [],
        'train_accuracy': [],
        'val_accuracy': []
    }

    total_start_time = time.time()
    total_batches = len(train_loader) * num_epochs
    batches_done = 0

    # max(..., 1) keeps the averages defined for an empty loader.
    num_train_batches = max(len(train_loader), 1)
    num_val_batches = max(len(val_loader), 1)

    print("\n=== Training Start ===")
    print(f"Start time: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    print(f"Total epochs: {num_epochs}")
    print(f"Batches per epoch: {len(train_loader)}")
    print(f"Total batches: {total_batches}")
    print(f"Early stopping accuracy threshold: {accuracy_threshold}%")
    print("=" * 50)

    for epoch in range(num_epochs):
        epoch_start = time.time()
        model.train()
        train_loss = 0.0
        train_bbox_loss = 0.0
        train_cls_loss = 0.0
        correct_preds = 0
        total_samples = 0

        # Estimate the remaining time from the average time per batch so far.
        elapsed_time = time.time() - total_start_time
        if batches_done > 0:
            time_per_batch = elapsed_time / batches_done
            eta = time_per_batch * (total_batches - batches_done)
        else:
            eta = 0

        print(f"\nEpoch {epoch+1}/{num_epochs}")
        print(f"Elapsed time: {format_time(elapsed_time)}")
        print(f"Estimated time remaining: {format_time(eta)}")
        print(f"Expected completion: {datetime.now() + timedelta(seconds=int(eta))}")

        train_pbar = tqdm(train_loader,
                          desc=f'Train [{format_time(elapsed_time)} < {format_time(eta)}]',
                          leave=False,
                          ncols=100)

        for batch_idx, (images, bboxes, labels) in enumerate(train_pbar):
            batch_start = time.time()

            images = images.to(device)
            bboxes = bboxes.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            pred_bboxes, pred_cls = model(images)

            loss_bbox = criterion_bbox(pred_bboxes, bboxes)
            loss_cls = criterion_cls(pred_cls, labels)
            loss = loss_bbox + loss_cls

            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            train_bbox_loss += loss_bbox.item()
            train_cls_loss += loss_cls.item()

            pred_labels = torch.argmax(pred_cls, dim=1)
            correct_preds += (pred_labels == labels).sum().item()
            total_samples += labels.size(0)

            batches_done += 1

            # A coarse clock can report a zero-length batch; avoid dividing by it.
            batch_time = time.time() - batch_start
            images_per_sec = (images.shape[0] / batch_time
                              if batch_time > 0 else float('inf'))

            train_pbar.set_postfix({
                'loss': f'{loss.item():.4f}',
                'bbox': f'{loss_bbox.item():.4f}',
                'cls': f'{loss_cls.item():.4f}',
                'img/s': f'{images_per_sec:.1f}'
            })

            # Print detailed metrics every 50 batches
            if (batch_idx + 1) % 50 == 0:
                elapsed = time.time() - total_start_time
                progress = batches_done / total_batches
                eta = (elapsed / progress) - elapsed if progress > 0 else 0

                print(f"\nBatch {batch_idx + 1}/{len(train_loader)} Statistics:")
                print(f"  Time elapsed: {format_time(elapsed)}")
                print(f"  Time remaining: {format_time(eta)}")
                print(f"  Processing speed: {images_per_sec:.1f} images/sec")
                print(f"  Loss: {loss.item():.4f}")
                print(f"  Bbox Loss: {loss_bbox.item():.4f}")
                print(f"  Class Loss: {loss_cls.item():.4f}")
                print(f"  Accuracy: {100 * correct_preds / total_samples:.2f}%")

        train_pbar.close()

        avg_train_loss = train_loss / num_train_batches
        avg_train_bbox_loss = train_bbox_loss / num_train_batches
        avg_train_cls_loss = train_cls_loss / num_train_batches
        train_accuracy = 100 * correct_preds / total_samples if total_samples else 0.0

        # Validation
        val_start = time.time()
        val_totals = _run_validation(
            model, val_loader, criterion_bbox, criterion_cls, device,
            desc=f'Val [Epoch {epoch+1}/{num_epochs}]')
        val_time = time.time() - val_start

        avg_val_loss = val_totals['loss'] / num_val_batches
        avg_val_bbox_loss = val_totals['bbox'] / num_val_batches
        avg_val_cls_loss = val_totals['cls'] / num_val_batches
        val_accuracy = (100 * val_totals['correct'] / val_totals['samples']
                        if val_totals['samples'] else 0.0)

        metrics_history['train_bbox_loss'].append(avg_train_bbox_loss)
        metrics_history['train_cls_loss'].append(avg_train_cls_loss)
        metrics_history['train_total_loss'].append(avg_train_loss)
        metrics_history['val_bbox_loss'].append(avg_val_bbox_loss)
        metrics_history['val_cls_loss'].append(avg_val_cls_loss)
        metrics_history['val_total_loss'].append(avg_val_loss)
        metrics_history['train_accuracy'].append(train_accuracy)
        metrics_history['val_accuracy'].append(val_accuracy)

        epoch_time = time.time() - epoch_start
        elapsed_total = time.time() - total_start_time
        eta_total = (elapsed_total / (epoch + 1)) * (num_epochs - (epoch + 1))

        print(f'\nEpoch {epoch+1}/{num_epochs} Complete:')
        print(f'Time Statistics:')
        print(f'  Epoch time: {format_time(epoch_time)}')
        print(f'  Training time: {format_time(epoch_time - val_time)}')
        print(f'  Validation time: {format_time(val_time)}')
        print(f'  Total elapsed: {format_time(elapsed_total)}')
        print(f'  Estimated remaining: {format_time(eta_total)}')
        print(f'Training Metrics:')
        print(f'  Total Loss: {avg_train_loss:.4f}')
        print(f'  Bbox Loss: {avg_train_bbox_loss:.4f}')
        print(f'  Class Loss: {avg_train_cls_loss:.4f}')
        print(f'  Accuracy: {train_accuracy:.2f}%')
        print(f'Validation Metrics:')
        print(f'  Total Loss: {avg_val_loss:.4f}')
        print(f'  Bbox Loss: {avg_val_bbox_loss:.4f}')
        print(f'  Class Loss: {avg_val_cls_loss:.4f}')
        print(f'  Accuracy: {val_accuracy:.2f}%')

        # Fold this epoch into the running bests BEFORE any checkpoint is
        # written, so an early-stop checkpoint never records a stale
        # best_loss (e.g. inf when stopping after the first epoch).
        is_best_accuracy = val_accuracy > best_accuracy
        is_best_loss = avg_val_loss < best_loss
        if is_best_accuracy:
            best_accuracy = val_accuracy
        if is_best_loss:
            best_loss = avg_val_loss

        if val_accuracy >= accuracy_threshold:
            print(f"\n=== Early Stopping ===")
            print(f"Validation accuracy {val_accuracy:.2f}% reached threshold {accuracy_threshold}%")
            print(f"Stopping training at epoch {epoch+1}/{num_epochs}")

            if is_best_accuracy:
                _save_checkpoint(checkpoint_path, epoch, model, optimizer,
                                 best_loss, best_accuracy, metrics_history,
                                 total_start_time)
                print(f'Final model saved with accuracy: {best_accuracy:.2f}%')
            break

        if is_best_loss:
            _save_checkpoint(checkpoint_path, epoch, model, optimizer,
                             best_loss, best_accuracy, metrics_history,
                             total_start_time)
            print(f'New best model saved! (Val Loss: {best_loss:.4f}, Val Accuracy: {val_accuracy:.2f}%)')

        print('=' * 80)

    total_time = time.time() - total_start_time
    print('\n=== Training Complete ===')
    print(f'Total training time: {format_time(total_time)}')
    print(f'End time: {datetime.now().strftime("%Y-%m-%d %H:%M:%S")}')
    print(f'Best validation loss: {best_loss:.4f}')
    print(f'Best validation accuracy: {best_accuracy:.2f}%')
    print('=' * 50)

    return metrics_history
def visualize_predictions(model, test_loader, device, num_images=4):
    """Plot predicted (red) and ground-truth (green) boxes for test samples.

    Only the first sample of each batch is drawn, for up to ``num_images``
    batches. The grid has two columns and as many rows as ``num_images``
    needs, so values above four no longer overrun a fixed 2x2 grid.

    Args:
        model: Module returning ``(bbox, class_logits)`` for a batch.
        test_loader: Iterable of ``(images, bboxes, labels)`` batches; boxes
            are ``[x, y, width, height]`` as fractions of the image size.
        device: Device the images are moved to before inference.
        num_images: Number of batches to visualize; nothing is drawn when
            it is zero or negative.
    """
    if num_images <= 0:
        return

    model.eval()

    # squeeze=False keeps `axes` two-dimensional even for a single row.
    num_cols = 2
    num_rows = (num_images + num_cols - 1) // num_cols
    _, axes = plt.subplots(num_rows, num_cols,
                           figsize=(10, 5 * num_rows), squeeze=False)
    axes = axes.ravel()

    # Hide every axis up front so unused grid cells stay blank.
    for ax in axes:
        ax.axis('off')

    with torch.no_grad():
        # zip stops at whichever runs out first: the requested number of
        # axes or the batches in the loader.
        for ax, (images, bboxes, labels) in zip(axes[:num_images], test_loader):
            images = images.to(device)
            pred_bboxes, pred_cls = model(images)

            pred_bbox = pred_bboxes[0].cpu().numpy()
            pred_label = torch.argmax(pred_cls[0]).item()
            true_bbox = bboxes[0].numpy()
            true_label = labels[0].item()

            img = images[0].cpu().squeeze()
            ax.imshow(img, cmap='gray')

            # Boxes are fractions of the image size; scale back to pixels.
            h, w = img.shape
            pred_rect = plt.Rectangle((pred_bbox[0]*w, pred_bbox[1]*h),
                                      pred_bbox[2]*w, pred_bbox[3]*h,
                                      fill=False, edgecolor='red', linewidth=2)
            true_rect = plt.Rectangle((true_bbox[0]*w, true_bbox[1]*h),
                                      true_bbox[2]*w, true_bbox[3]*h,
                                      fill=False, edgecolor='green', linewidth=2)
            ax.add_patch(pred_rect)
            ax.add_patch(true_rect)

            ax.set_title(f'Pred: {pred_label}, True: {true_label}')

    plt.tight_layout()
    plt.show()
def plot_training_history(metrics_history):
    """Draw loss curves (left panel) and accuracy curves (right panel) per epoch.

    ``metrics_history`` must provide the keys 'train_total_loss',
    'val_total_loss', 'train_bbox_loss', 'train_cls_loss', 'train_accuracy'
    and 'val_accuracy', each mapping to a sequence with one value per epoch.
    """
    loss_curves = (
        ('train_total_loss', 'Train Loss'),
        ('val_total_loss', 'Val Loss'),
        ('train_bbox_loss', 'Train Bbox Loss'),
        ('train_cls_loss', 'Train Cls Loss'),
    )
    accuracy_curves = (
        ('train_accuracy', 'Train Accuracy'),
        ('val_accuracy', 'Val Accuracy'),
    )
    # (subplot position, curves, panel title, y-axis label)
    panels = (
        (1, loss_curves, 'Training and Validation Losses', 'Loss'),
        (2, accuracy_curves, 'Training and Validation Accuracy', 'Accuracy (%)'),
    )

    plt.figure(figsize=(15, 5))
    for position, curves, panel_title, y_label in panels:
        plt.subplot(1, 2, position)
        for key, curve_label in curves:
            plt.plot(metrics_history[key], label=curve_label)
        plt.title(panel_title)
        plt.xlabel('Epoch')
        plt.ylabel(y_label)
        plt.legend()

    plt.tight_layout()
    plt.show()
def load_and_visualize_model(model_path='best_model.pth', data_root='mnist_coco', num_images=4):
    """
    Load a trained model, visualize its predictions and report test metrics

    The prediction grid has two columns and as many rows as ``num_images``
    requires. After plotting, the whole test split is evaluated for
    classification accuracy and mean bounding-box MSE.

    Args:
        model_path: Path to the saved model checkpoint
        data_root: Directory containing the dataset
        num_images: Number of images to visualize (no figure when <= 0)

    Returns:
        The loaded model, in eval mode on the selected device
    """
    # Set device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Create model and load weights
    model = SimpleDetector()
    checkpoint = torch.load(model_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model = model.to(device)
    model.eval()

    # Print model information
    print("\n=== Model Information ===")
    print(f"Checkpoint from epoch: {checkpoint['epoch'] + 1}")
    print(f"Best validation loss: {checkpoint['best_loss']:.4f}")
    print(f"Best validation accuracy: {checkpoint['best_accuracy']:.2f}%")
    print(f"Total training time: {format_time(checkpoint['training_time'])}")

    # Create test dataset and loader
    transform = transforms.Compose([
        transforms.Resize((28, 28)),
        transforms.ToTensor(),
    ])
    test_dataset = MNISTCOCODataset(root_dir=data_root, split='test', transform=transform)
    test_loader = DataLoader(test_dataset, batch_size=1, shuffle=True)

    # Visualize predictions
    print("\n=== Visualizing Predictions ===")
    if num_images > 0:
        # squeeze=False keeps `axes` two-dimensional even for a single row.
        num_cols = 2
        num_rows = (num_images + num_cols - 1) // num_cols
        _, axes = plt.subplots(num_rows, num_cols,
                               figsize=(12, 6 * num_rows), squeeze=False)
        axes = axes.ravel()

        # Hide every axis up front so unused grid cells stay blank.
        for ax in axes:
            ax.axis('off')

        with torch.no_grad():
            samples = zip(axes[:num_images], test_loader)
            for idx, (ax, (images, bboxes, labels)) in enumerate(samples):
                images = images.to(device)
                pred_bboxes, pred_cls = model(images)

                pred_bbox = pred_bboxes[0].cpu().numpy()
                pred_label = torch.argmax(pred_cls[0]).item()
                true_bbox = bboxes[0].numpy()
                true_label = labels[0].item()

                img = images[0].cpu().squeeze()
                ax.imshow(img, cmap='gray')

                # Boxes are fractions of the image size; scale back to pixels.
                h, w = img.shape
                pred_rect = plt.Rectangle((pred_bbox[0]*w, pred_bbox[1]*h),
                                          pred_bbox[2]*w, pred_bbox[3]*h,
                                          fill=False, edgecolor='red', linewidth=2,
                                          label='Prediction')
                true_rect = plt.Rectangle((true_bbox[0]*w, true_bbox[1]*h),
                                          true_bbox[2]*w, true_bbox[3]*h,
                                          fill=False, edgecolor='green', linewidth=2,
                                          label='Ground Truth')
                ax.add_patch(pred_rect)
                ax.add_patch(true_rect)

                is_correct = pred_label == true_label
                ax.set_title(f'Pred: {pred_label} (True: {true_label})\n' +
                             f'{"✓" if is_correct else "✗"}',
                             color='green' if is_correct else 'red')

                # One legend is enough; the colors mean the same everywhere.
                if idx == 0:
                    ax.legend()

        plt.suptitle('Model Predictions on Test Set\nRed: Predicted Box, Green: True Box',
                     fontsize=12, y=0.95)
        plt.tight_layout()
        plt.show()

    # Calculate and display test metrics
    print("\n=== Test Set Metrics ===")
    correct = 0
    total = 0
    bbox_error = 0
    criterion_bbox = nn.MSELoss()  # built once instead of once per batch

    with torch.no_grad():
        for images, bboxes, labels in tqdm(test_loader, desc="Evaluating"):
            images = images.to(device)
            bboxes = bboxes.to(device)
            labels = labels.to(device)

            pred_bboxes, pred_cls = model(images)
            pred_labels = torch.argmax(pred_cls, dim=1)

            # Classification accuracy
            correct += (pred_labels == labels).sum().item()
            total += labels.size(0)

            # Bounding box error (MSE)
            bbox_error += criterion_bbox(pred_bboxes, bboxes).item()

    # Guard the averages so an empty test split reports zeros, not a crash.
    test_accuracy = 100 * correct / total if total else 0.0
    avg_bbox_error = bbox_error / len(test_loader) if len(test_loader) else 0.0

    print(f"Test Accuracy: {test_accuracy:.2f}%")
    print(f"Average Bounding Box Error: {avg_bbox_error:.4f}")

    return model
def analyze_failure_cases(model_path='best_model.pth', data_root='mnist_coco', num_failures=8):
    """
    Analyze and display failure cases where the model made incorrect predictions

    A test sample counts as a failure when its predicted class is wrong or
    its bounding-box MSE exceeds 0.1. The function prints failure
    statistics, shows the failures with the largest box error, a
    row-normalized confusion matrix over all test samples, and a histogram
    of box errors among the failures.

    Args:
        model_path: Path to the saved model checkpoint
        data_root: Directory containing the dataset
        num_failures: Number of failure cases to display
    """
    # Set device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Create model and load weights
    model = SimpleDetector()
    checkpoint = torch.load(model_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model = model.to(device)
    model.eval()

    # Create test dataset and loader
    transform = transforms.Compose([
        transforms.Resize((28, 28)),
        transforms.ToTensor(),
    ])
    test_dataset = MNISTCOCODataset(root_dir=data_root, split='test', transform=transform)
    # batch_size=1 is required: the loop below calls .item() on the labels.
    test_loader = DataLoader(test_dataset, batch_size=1, shuffle=True)

    # Collect failure cases
    failure_cases = []
    confusion_matrix = torch.zeros(10, 10)  # 10 classes for MNIST
    bbox_errors = []
    criterion_bbox = nn.MSELoss()  # built once instead of once per sample

    print("Analyzing model predictions...")
    with torch.no_grad():
        for images, bboxes, labels in tqdm(test_loader, desc="Finding failures"):
            images = images.to(device)
            bboxes = bboxes.to(device)
            labels = labels.to(device)

            pred_bboxes, pred_cls = model(images)
            pred_labels = torch.argmax(pred_cls, dim=1)

            bbox_error = criterion_bbox(pred_bboxes, bboxes).item()

            # Rows are true labels, columns are predicted labels.
            confusion_matrix[labels.item(), pred_labels.item()] += 1

            # Wrong class, or a box error above the 0.1 threshold.
            if pred_labels.item() != labels.item() or bbox_error > 0.1:
                failure_cases.append({
                    'image': images[0].cpu(),
                    'true_label': labels.item(),
                    'pred_label': pred_labels.item(),
                    'true_bbox': bboxes[0].cpu().numpy(),
                    'pred_bbox': pred_bboxes[0].cpu().numpy(),
                    'bbox_error': bbox_error
                })
                bbox_errors.append(bbox_error)

    # Worst box errors first.
    failure_cases.sort(key=lambda x: x['bbox_error'], reverse=True)

    # Calculate statistics (an empty test split reports a 0% failure rate).
    total_failures = len(failure_cases)
    total_samples = len(test_loader)
    failure_rate = (total_failures / total_samples) * 100 if total_samples else 0.0

    print("\n=== Failure Analysis ===")
    print(f"Total samples tested: {total_samples}")
    print(f"Number of failures: {total_failures}")
    print(f"Failure rate: {failure_rate:.2f}%")

    if total_failures == 0:
        print("No failures found!")
        return

    # Display worst failure cases; skipped when num_failures <= 0 because a
    # zero-row figure is invalid.
    shown_cases = failure_cases[:max(num_failures, 0)]
    if shown_cases:
        num_cols = 4
        num_rows = (len(shown_cases) + num_cols - 1) // num_cols
        plt.figure(figsize=(15, 4 * num_rows))

        for idx, case in enumerate(shown_cases):
            plt.subplot(num_rows, num_cols, idx + 1)

            img = case['image'].squeeze()
            plt.imshow(img, cmap='gray')

            # Boxes are fractions of the image size; scale back to pixels.
            h, w = img.shape
            pred_bbox = case['pred_bbox']
            true_bbox = case['true_bbox']

            # Predicted bbox in red
            pred_rect = plt.Rectangle((pred_bbox[0]*w, pred_bbox[1]*h),
                                      pred_bbox[2]*w, pred_bbox[3]*h,
                                      fill=False, edgecolor='red', linewidth=2,
                                      label='Prediction')
            # True bbox in green
            true_rect = plt.Rectangle((true_bbox[0]*w, true_bbox[1]*h),
                                      true_bbox[2]*w, true_bbox[3]*h,
                                      fill=False, edgecolor='green', linewidth=2,
                                      label='Ground Truth')

            plt.gca().add_patch(pred_rect)
            plt.gca().add_patch(true_rect)

            plt.title(f'True: {case["true_label"]}, Pred: {case["pred_label"]}\n' +
                      f'BBox Error: {case["bbox_error"]:.4f}',
                      color='red')

            if idx == 0:
                plt.legend()

            plt.axis('off')

        plt.suptitle('Worst Failure Cases\nRed: Predicted Box, Green: True Box',
                     fontsize=14, y=1.02)
        plt.tight_layout()

    # Plot confusion matrix. Row totals are clamped to at least 1 so a class
    # with no test samples yields a row of zeros instead of NaN (0 / 0).
    plt.figure(figsize=(10, 8))
    row_totals = confusion_matrix.sum(dim=1, keepdim=True).clamp(min=1)
    confusion_matrix_normalized = confusion_matrix / row_totals
    plt.imshow(confusion_matrix_normalized.numpy(), cmap='YlOrRd')
    plt.colorbar()
    plt.title('Confusion Matrix (Normalized)')
    plt.xlabel('Predicted Label')
    plt.ylabel('True Label')

    # Add text annotations to the confusion matrix
    for i in range(10):
        for j in range(10):
            plt.text(j, i, f'{confusion_matrix_normalized[i, j]:.2f}',
                     ha='center', va='center')

    # Plot bbox error distribution
    if bbox_errors:
        plt.figure(figsize=(10, 5))
        plt.hist(bbox_errors, bins=50, edgecolor='black')
        plt.title('Distribution of Bounding Box Errors in Failure Cases')
        plt.xlabel('Bounding Box Error (MSE)')
        plt.ylabel('Count')

    plt.show()

    # Print detailed statistics
    print("\nFailure Type Analysis:")
    print(f"Classification errors: {sum(case['true_label'] != case['pred_label'] for case in failure_cases)}")
    print(f"Significant bbox errors (>0.1): {sum(case['bbox_error'] > 0.1 for case in failure_cases)}")

    if bbox_errors:
        print(f"\nBounding Box Error Statistics:")
        print(f"Mean error: {np.mean(bbox_errors):.4f}")
        print(f"Median error: {np.median(bbox_errors):.4f}")
        print(f"Max error: {np.max(bbox_errors):.4f}")
        print(f"Min error: {np.min(bbox_errors):.4f}")
def main():
    """Train the detector on the MNIST COCO data and show sample test predictions.

    The training split is divided 80/20 into training and validation
    subsets; after training, the weights stored in ``best_model.pth`` are
    reloaded before predictions on the test split are plotted.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Images are resized to 28x28 and converted to tensors.
    preprocess = transforms.Compose([
        transforms.Resize((28, 28)),
        transforms.ToTensor(),
    ])

    full_train = MNISTCOCODataset(root_dir='mnist_coco', split='train', transform=preprocess)
    test_dataset = MNISTCOCODataset(root_dir='mnist_coco', split='test', transform=preprocess)

    # 80% of the training split trains the model; the rest validates it.
    num_train = int(0.8 * len(full_train))
    num_val = len(full_train) - num_train
    train_subset, val_subset = torch.utils.data.random_split(full_train, [num_train, num_val])

    train_loader = DataLoader(train_subset, batch_size=32, shuffle=True)
    val_loader = DataLoader(val_subset, batch_size=32)
    test_loader = DataLoader(test_dataset, batch_size=1)

    # Training stops early once validation accuracy reaches 98%.
    model = SimpleDetector()
    metrics_history = train_model(
        model,
        train_loader,
        val_loader,
        num_epochs=10,
        device=device,
        accuracy_threshold=98.0,
    )

    plot_training_history(metrics_history)

    # Restore the checkpointed weights before visualizing predictions.
    checkpoint = torch.load('best_model.pth')
    model.load_state_dict(checkpoint['model_state_dict'])
    print(f"\nBest model performance:")
    print(f"Validation Loss: {checkpoint['best_loss']:.4f}")
    print(f"Validation Accuracy: {checkpoint['best_accuracy']:.2f}%")

    visualize_predictions(model, test_loader, device)
if __name__ == "__main__":
    # An existing checkpoint switches the script from training to analysis.
    if os.path.exists('best_model.pth'):
        print("\nVisualizing model predictions...")
        load_and_visualize_model('best_model.pth')

        print("\nAnalyzing failure cases...")
        analyze_failure_cases('best_model.pth')
    else:
        main()
Save this file as `ObjectDetection.py`. This file includes:
- `MNISTCOCODataset`: Custom dataset loader.
- `SimpleDetector`: A CNN with two heads (bbox and class).
- `train_model`: Training loop with detailed logging and early stopping.
- `visualize_predictions`: Shows true and predicted boxes.
- `analyze_failure_cases`: Displays worst predictions and computes confusion matrix.
- `main`: Runs training, saves the model, plots results.
▶️ 3. How to Run
Once you've downloaded the dataset and written both scripts:
Option A: Train the Model
python ObjectDetection.py
If `best_model.pth` doesn’t exist, this will train a model from scratch, saving a checkpoint whenever the validation loss improves and stopping early once validation accuracy reaches your defined threshold (default: 98%).
Option B: Visualize and Analyze
If `best_model.pth` exists:
python ObjectDetection.py
It will:
- Load the trained model
- Visualize predictions vs. ground truth
- Show failure cases with bounding box errors and a confusion matrix
Example Output
- Training Graphs (loss and accuracy)
- Visual Predictions: Red = predicted box, Green = ground truth
- Failure Cases: Includes misclassified digits or inaccurate bounding boxes
- Confusion Matrix: A heatmap showing which digits were confused
- BBox Error Distribution: Histogram of errors
🧪 Why This Project is Valuable
This hands-on mini-project is perfect if you:
- Want to learn object detection basics.
- Are preparing to work with custom datasets in COCO format.
- Want to test the full ML pipeline from data preparation to evaluation.
You’ll gain practical knowledge of:
- COCO JSON annotation structure
- Custom PyTorch datasets and dataloaders
- Dual-loss training (regression + classification)
- Visual debugging of model performance
✅ Requirements
Install the dependencies:
pip install torch torchvision matplotlib tqdm pillow
📦 Directory Structure
.
├── mnist_download.py
├── ObjectDetection.py
├── mnist_coco/
│ ├── images/
│ └── annotations/
└── best_model.pth (generated after training)
Final Words
This post walks you through a complete object detection pipeline using MNIST digits, giving you a reusable framework for other small detection tasks. You’ve learned how to:
- Convert a dataset to COCO format.
- Train a detection model with bounding box and class loss.
- Visualize results and analyze performance issues.
Use this as a base for expanding into more complex datasets or refining your object detector with new architectures like SSD or YOLO!
Comments
Post a Comment