AlexNet in PyTorch – Building a CNN Model

AlexNet revolutionized computer vision in 2012, when it won the ImageNet Large Scale Visual Recognition Challenge (ILSVRC) by a wide margin and demonstrated the power of deep convolutional neural networks trained on GPUs. This pioneering architecture laid the groundwork for modern deep learning and remains an excellent starting point for understanding CNNs. You’ll learn how to implement AlexNet from scratch in PyTorch, understand its key architectural components, explore practical applications, and apply optimization and deployment techniques for real-world computer vision tasks.

Understanding AlexNet Architecture

AlexNet has 8 learned layers: 5 convolutional layers followed by 3 fully connected layers. It uses ReLU activations, dropout for regularization, local response normalization, and overlapping max pooling. The network takes 224×224×3 input images and produces one score (logit) for each of the 1000 ImageNet categories; a softmax turns these scores into class probabilities.

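Key innovations, covering both the architecture and how it was trained, include:

  • ReLU activations instead of saturating sigmoid/tanh units, which made training considerably faster
  • Dropout layers to prevent overfitting
  • Data augmentation techniques
  • GPU acceleration using CUDA (the original network was split across two GPUs)
  • Local Response Normalization (LRN)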
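Local Response Normalization divides each activation by a term built from the squared activations of neighbouring channels at the same spatial position. The plain-Python sketch below illustrates this for a single pixel. It mirrors the formula documented for torch.nn.LocalResponseNorm, and the function name and list-based input are our own. One detail worth knowing: PyTorch multiplies the sum of squares by alpha / size, while the AlexNet paper multiplies it by α directly. Passing alpha=1e-4 to the PyTorch module therefore gives a smaller effective constant than the paper's α = 10⁻⁴.

```python
def local_response_norm(channels, size=5, alpha=1e-4, beta=0.75, k=2.0):
    """Cross-channel LRN for the activations at one spatial position.

    channels: list of activations a_0 .. a_{N-1}, one per channel.
    Mirrors the torch.nn.LocalResponseNorm convention:
        b_c = a_c / (k + (alpha / size) * sum of a_j**2 over the window) ** beta
    """
    n = len(channels)
    out = []
    for c in range(n):
        # Window of `size` channels around c, clipped at the first/last channel
        lo = max(0, c - size // 2)
        hi = min(n - 1, c + (size - 1) // 2)
        sq_sum = sum(channels[j] ** 2 for j in range(lo, hi + 1))
        out.append(channels[c] / (k + (alpha / size) * sq_sum) ** beta)
    return out


# Usage: normalize three channel activations at one pixel
print(local_response_norm([1.0, 2.0, 3.0]))
```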

| Layer | Type | Output Size | Configuration |
|---|---|---|---|
| Conv1 | Convolution | 55×55×96 | 11×11 kernel, stride 4 |
| Conv2 | Convolution | 27×27×256 | 5×5 kernel, stride 1 |
| Conv3 | Convolution | 13×13×384 | 3×3 kernel, stride 1 |
| Conv4 | Convolution | 13×13×384 | 3×3 kernel, stride 1 |
| Conv5 | Convolution | 13×13×256 | 3×3 kernel, stride 1 |
| FC1 | Fully Connected | 4096 | Dropout 0.5 |
| FC2 | Fully Connected | 4096 | Dropout 0.5 |
| FC3 | Fully Connected | 1000 | Output layer |
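The output sizes above follow from the standard formula for convolution and pooling layers, out = floor((in + 2·padding − kernel) / stride) + 1. The short sketch below traces a 224×224 input through the spatial layers. It assumes the padding values used in the PyTorch code in the next section, including padding 2 on Conv1, which is what lets a 224-pixel input reach 55×55. It also assumes 3×3 max pooling with stride 2 after Conv1, Conv2 and Conv5. The helper names are illustrative.

```python
def out_size(size, kernel, stride=1, padding=0):
    """Spatial output size of a convolution or pooling layer (floor division)."""
    return (size + 2 * padding - kernel) // stride + 1


# (name, kernel, stride, padding) for each spatial layer, following the PyTorch code
layers = [
    ("conv1", 11, 4, 2),
    ("pool1", 3, 2, 0),
    ("conv2", 5, 1, 2),
    ("pool2", 3, 2, 0),
    ("conv3", 3, 1, 1),
    ("conv4", 3, 1, 1),
    ("conv5", 3, 1, 1),
    ("pool5", 3, 2, 0),
]

def trace_shapes(size=224):
    """Return (layer name, spatial size) pairs for a square input of the given size."""
    shapes = []
    for name, kernel, stride, padding in layers:
        size = out_size(size, kernel, stride, padding)
        shapes.append((name, size))
    return shapes

for name, size in trace_shapes():
    print(f"{name}: {size}x{size}")
```

The final 6×6 map is why the first fully connected layer takes 256 × 6 × 6 inputs. Without padding, 227×227 is the input size that yields exactly 55×55, since (227 − 11) / 4 + 1 = 55. This is a well-known discrepancy with the 224 stated in the paper.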

Step-by-Step PyTorch Implementation

Let’s build AlexNet from scratch using PyTorch. Start by importing the necessary libraries and defining the network architecture:

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

class AlexNet(nn.Module):
    def __init__(self, num_classes=1000):
        super(AlexNet, self).__init__()
        
        self.features = nn.Sequential(
            # First convolutional layer
            nn.Conv2d(3, 96, kernel_size=11, stride=4, padding=2),
            nn.ReLU(inplace=True),
            nn.LocalResponseNorm(size=5, alpha=0.0001, beta=0.75, k=2),
            nn.MaxPool2d(kernel_size=3, stride=2),
            
            # Second convolutional layer
            nn.Conv2d(96, 256, kernel_size=5, stride=1, padding=2),
            nn.ReLU(inplace=True),
            nn.LocalResponseNorm(size=5, alpha=0.0001, beta=0.75, k=2),
            nn.MaxPool2d(kernel_size=3, stride=2),
            
            # Third convolutional layer
            nn.Conv2d(256, 384, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            
            # Fourth convolutional layer
            nn.Conv2d(384, 384, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            
            # Fifth convolutional layer
            nn.Conv2d(384, 256, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
        )
        
        self.avgpool = nn.AdaptiveAvgPool2d((6, 6))
        
        self.classifier = nn.Sequential(
            nn.Dropout(p=0.5),
            nn.Linear(256 * 6 * 6, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, num_classes),
        )
    
    def forward(self, x):
        x = self.features(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x
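A quick way to sanity-check the class above is to count its learnable parameters by hand. A convolution has out_channels × in_channels × k × k weights plus one bias per output channel, and a linear layer has in_features × out_features weights plus out_features biases. The plain-Python sketch below applies these formulas to the layer sizes used in the class; the helper names are hypothetical. In PyTorch itself, sum(p.numel() for p in model.parameters()) should give the same total.

```python
def conv_params(in_ch, out_ch, k):
    # weights (out_ch x in_ch x k x k) plus one bias per output channel
    return out_ch * in_ch * k * k + out_ch

def linear_params(in_f, out_f):
    # weight matrix (out_f x in_f) plus one bias per output feature
    return out_f * in_f + out_f

def alexnet_param_count(num_classes=1000):
    """Return (conv parameters, fully connected parameters) for the class above."""
    convs = [
        conv_params(3, 96, 11),
        conv_params(96, 256, 5),
        conv_params(256, 384, 3),
        conv_params(384, 384, 3),
        conv_params(384, 256, 3),
    ]
    fcs = [
        linear_params(256 * 6 * 6, 4096),
        linear_params(4096, 4096),
        linear_params(4096, num_classes),
    ]
    return sum(convs), sum(fcs)

conv_total, fc_total = alexnet_param_count()
print(f"conv: {conv_total:,}  fc: {fc_total:,}  total: {conv_total + fc_total:,}")
# conv: 3,747,200  fc: 58,631,144  total: 62,378,344
```

Over 90% of the parameters sit in the fully connected layers, and the first one alone contributes about 37.8M. With num_classes=10, as in the CIFAR-10 setup below, the total drops to 58,322,314. The original paper reports about 60 million parameters for its two-GPU model, in which some convolutions only connect to feature maps on the same GPU; this single-tower version is slightly larger.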

Now let’s set up the training pipeline with data preprocessing and augmentation:

# Data preprocessing and augmentation
transform_train = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(10),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

transform_test = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
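Both pipelines end with ToTensor, which scales 8-bit pixel values into [0, 1], followed by Normalize, which computes (x − mean) / std per channel using the ImageNet statistics. The sketch below reproduces that arithmetic for a single RGB pixel in plain Python so the resulting value range is easy to see. The helper normalize_pixel is hypothetical and not part of torchvision.

```python
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

def normalize_pixel(rgb, mean=IMAGENET_MEAN, std=IMAGENET_STD):
    """Apply ToTensor-style scaling (/255), then per-channel (x - mean) / std."""
    return tuple((value / 255.0 - m) / s for value, m, s in zip(rgb, mean, std))

print(normalize_pixel((0, 0, 0)))        # darkest pixel, roughly (-2.12, -2.04, -1.80)
print(normalize_pixel((255, 255, 255)))  # brightest pixel, roughly (2.25, 2.43, 2.64)
```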

# Load dataset (example with CIFAR-10; Resize upsamples its 32x32 images to 224x224)
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, 
                                       download=True, transform=transform_train)
trainloader = DataLoader(trainset, batch_size=128, shuffle=True, num_workers=4)

testset = torchvision.datasets.CIFAR10(root='./data', train=False, 
                                      download=True, transform=transform_test)
testloader = DataLoader(testset, batch_size=128, shuffle=False, num_workers=4)

# Initialize model, loss function, and optimizer
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = AlexNet(num_classes=10).to(device)  # CIFAR-10 has 10 classes
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)
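StepLR multiplies the learning rate by gamma every step_size epochs. With lr=0.01, step_size=30 and gamma=0.1, the rate is 0.01 for epochs 0 to 29 and 0.001 for epochs 30 to 59, continuing in that pattern. This assumes scheduler.step() is called once per epoch, as the training loop below does. The closed form is easy to check in plain Python; step_lr is an illustrative helper, not a PyTorch function.

```python
def step_lr(epoch, base_lr=0.01, step_size=30, gamma=0.1):
    """Learning rate StepLR applies during a given (0-based) epoch."""
    return base_lr * gamma ** (epoch // step_size)

for epoch in (0, 29, 30, 60, 99):
    print(epoch, step_lr(epoch))
```

Over the 100-epoch run used below, the last ten epochs (90 to 99) train at 1e-5.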

Implement the training loop with proper monitoring and validation:

def train_model(model, trainloader, testloader, criterion, optimizer, scheduler, num_epochs=100):
    best_acc = 0.0
    train_losses = []
    train_accuracies = []
    
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0
        
        for batch_idx, (inputs, targets) in enumerate(trainloader):
            inputs, targets = inputs.to(device), targets.to(device)
            
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()
            
            running_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
            
            if batch_idx % 100 == 0:
                print(f'Epoch: {epoch}, Batch: {batch_idx}, '
                      f'Loss: {running_loss/(batch_idx+1):.3f}, '
                      f'Acc: {100.*correct/total:.2f}%')
        
        # Validation
        model.eval()
        test_loss = 0
        test_correct = 0
        test_total = 0
        
        with torch.no_grad():
            for inputs, targets in testloader:
                inputs, targets = inputs.to(device), targets.to(device)
                outputs = model(inputs)
                loss = criterion(outputs, targets)
                
                test_loss += loss.item()
                _, predicted = outputs.max(1)
                test_total += targets.size(0)
                test_correct += predicted.eq(targets).sum().item()
        
        test_acc = 100. * test_correct / test_total
        print(f'Epoch {epoch}: Train Acc: {100.*correct/total:.2f}%, '
              f'Test Loss: {test_loss/len(testloader):.3f}, '
              f'Test Acc: {test_acc:.2f}%')
        
        # Save best model
        if test_acc > best_acc:
            best_acc = test_acc
            torch.save(model.state_dict(), 'alexnet_best.pth')
        
        scheduler.step()
    
    # Returns the final-epoch weights; the best checkpoint is saved in alexnet_best.pth
    return model

# Start training
trained_model = train_model(model, trainloader, testloader, criterion, optimizer, scheduler)

Real-World Applications and Use Cases

A trained AlexNet can serve as a foundation for several computer vision tasks. Two practical examples are single-image classification and feature extraction for transfer learning:

Image Classification Pipeline:

from PIL import Image  # used by Image.open below

def classify_image(model, image_path, class_names):
    model.eval()
    
    # Load and preprocess image
    image = Image.open(image_path).convert('RGB')
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    
    input_tensor = transform(image).unsqueeze(0).to(device)
    
    with torch.no_grad():
        outputs = model(input_tensor)
        probabilities = torch.nn.functional.softmax(outputs[0], dim=0)
        
    # Get top 5 predictions
    top5_prob, top5_indices = torch.topk(probabilities, 5)
    
    results = []
    for i in range(top5_prob.size(0)):
        results.append({
            'class': class_names[top5_indices[i].item()],
            'probability': top5_prob[i].item()
        })
    
    return results

# Example usage
class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer', 
               'dog', 'frog', 'horse', 'ship', 'truck']
predictions = classify_image(trained_model, 'test_image.jpg', class_names)
for pred in predictions:
    print(f"{pred['class']}: {pred['probability']:.4f}")
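classify_image turns the raw scores (logits) into probabilities with a softmax and then keeps the k largest. The plain-Python sketch below shows the same two steps. Subtracting the maximum logit before exponentiating is the usual trick to avoid overflow and does not change the result. The helper names and the logits are made up for illustration.

```python
import math

def softmax(logits):
    """Numerically stable softmax over a list of scores."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def top_k(probabilities, class_names, k=5):
    """Return the k most probable (class name, probability) pairs, highest first."""
    ranked = sorted(range(len(probabilities)), key=lambda i: probabilities[i], reverse=True)
    return [(class_names[i], probabilities[i]) for i in ranked[:k]]

class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer',
               'dog', 'frog', 'horse', 'ship', 'truck']
logits = [0.5, -1.0, 0.2, 3.1, 0.0, 2.4, -0.3, 0.1, -2.0, 1.0]  # illustrative values
for name, p in top_k(softmax(logits), class_names, k=3):
    print(f"{name}: {p:.4f}")
```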

Feature Extraction for Transfer Learning:

class AlexNetFeatureExtractor(nn.Module):
    def __init__(self, pretrained_model):
        super(AlexNetFeatureExtractor, self).__init__()
        self.features = pretrained_model.features
        self.avgpool = pretrained_model.avgpool
        
    def forward(self, x):
        x = self.features(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        return x

# Extract features for custom dataset
feature_extractor = AlexNetFeatureExtractor(trained_model)
feature_extractor.eval()

def extract_features(dataloader, feature_extractor):
    features = []
    labels = []
    
    with torch.no_grad():
        for inputs, targets in dataloader:
            inputs = inputs.to(device)
            batch_features = feature_extractor(inputs)
            features.append(batch_features.cpu())
            labels.append(targets)
    
    return torch.cat(features), torch.cat(labels)

# Use extracted features for downstream tasks
extracted_features, extracted_labels = extract_features(testloader, feature_extractor)
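Once the 9,216-dimensional feature vectors (256 × 6 × 6) are extracted, a lightweight classifier can be fit on top of them. As one illustrative option, chosen here for simplicity rather than prescribed above, the sketch below implements a nearest-centroid classifier in plain Python on tiny made-up vectors. With real features you could convert the tensors with .tolist() first, or use a library classifier instead.

```python
import math
from collections import defaultdict

def fit_centroids(features, labels):
    """Average the feature vectors of each class into one centroid per class."""
    sums, counts = {}, defaultdict(int)
    for vec, label in zip(features, labels):
        acc = sums.setdefault(label, [0.0] * len(vec))
        for j, v in enumerate(vec):
            acc[j] += v
        counts[label] += 1
    return {label: [s / counts[label] for s in acc] for label, acc in sums.items()}

def predict(centroids, vec):
    """Assign vec to the class whose centroid is closest in Euclidean distance."""
    return min(centroids, key=lambda label: math.dist(vec, centroids[label]))

# Tiny illustrative 3-D "features"; real AlexNet features have 9,216 dimensions
train_x = [[1.0, 0.0, 0.0], [0.9, 0.1, 0.0], [0.0, 1.0, 0.2], [0.1, 0.9, 0.0]]
train_y = [0, 0, 1, 1]
centroids = fit_centroids(train_x, train_y)
print(predict(centroids, [0.8, 0.2, 0.1]))  # prints 0
```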

Performance Comparison and Optimization

Modern architectures reach higher accuracy than AlexNet with far fewer parameters, but understanding the trade-offs helps you choose the right one:

| Model | Parameters | ImageNet Top-1 Accuracy | Inference Time (ms) | Memory Usage (MB) |
|---|---|---|---|---|
| AlexNet | 61M | 56.5% | 2.3 | 217 |
| ResNet-18 | 11.7M | 69.8% | 1.8 | 44 |
| EfficientNet-B0 | 5.3M | 77.3% | 3.1 | 20 |
| MobileNet-V2 | 3.5M | 72.0% | 1.2 | 14 |
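A useful rule of thumb when reading the memory column: stored weights take 4 bytes per parameter in float32, 2 in float16 and 1 in int8. The sketch below applies that rule to the parameter counts in the table, treating 1 MB as 10⁶ bytes. The helper is illustrative and covers weights only. Real memory use during inference also includes activations and framework overhead, so measured figures will differ.

```python
BYTES_PER_PARAM = {"float32": 4, "float16": 2, "int8": 1}

def weight_megabytes(num_params, dtype="float32"):
    """Storage needed for the weights alone, in MB (1 MB = 10**6 bytes)."""
    return num_params * BYTES_PER_PARAM[dtype] / 1e6

for name, params in [("AlexNet", 61_000_000), ("ResNet-18", 11_700_000),
                     ("EfficientNet-B0", 5_300_000), ("MobileNet-V2", 3_500_000)]:
    print(f"{name}: {weight_megabytes(params):.1f} MB fp32, "
          f"{weight_megabytes(params, 'int8'):.1f} MB int8")
```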

Optimize AlexNet performance with these techniques:

# Mixed precision training for faster training
from torch.cuda.amp import autocast, GradScaler

scaler = GradScaler()

def train_with_amp(model, trainloader, criterion, optimizer):
    model.train()
    
    for inputs, targets in trainloader:
        inputs, targets = inputs.to(device), targets.to(device)
        
        optimizer.zero_grad()
        
        with autocast():
            outputs = model(inputs)
            loss = criterion(outputs, targets)
        
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
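GradScaler keeps small float16 gradients from underflowing. It multiplies the loss by a large scale factor before backward() and unscales the gradients before the optimizer step. The scale is adjusted dynamically: if any gradient is inf or NaN, the step is skipped and the scale is reduced, and after a run of successful steps the scale is increased again. The class below is a simplified, plain-Python model of that update rule, not PyTorch's implementation. Its defaults follow what we understand PyTorch's to be: init_scale=2**16, growth_factor=2.0, backoff_factor=0.5 and growth_interval=2000.

```python
import math

class LossScaleModel:
    """Simplified model of dynamic loss scaling; not torch's actual code."""

    def __init__(self, init_scale=2.0 ** 16, growth_factor=2.0,
                 backoff_factor=0.5, growth_interval=2000):
        self.scale = init_scale
        self.growth_factor = growth_factor
        self.backoff_factor = backoff_factor
        self.growth_interval = growth_interval
        self.good_steps = 0

    def update(self, grads):
        """Process one step's gradients; return True if the optimizer step is applied."""
        if any(math.isinf(g) or math.isnan(g) for g in grads):
            # Overflow: skip this optimizer step and shrink the scale
            self.scale *= self.backoff_factor
            self.good_steps = 0
            return False
        self.good_steps += 1
        if self.good_steps == self.growth_interval:
            # A long enough run without overflow: try a larger scale
            self.scale *= self.growth_factor
            self.good_steps = 0
        return True

scaler_model = LossScaleModel(growth_interval=3)
print(scaler_model.update([1.0, float("inf")]), scaler_model.scale)  # False 32768.0
for _ in range(3):
    scaler_model.update([1.0, 2.0])
print(scaler_model.scale)  # 65536.0
```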

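# Model quantization for deployment
import copy

def quantize_model(model):
    # Work on a CPU copy: dynamically quantized layers run on CPU backends
    model = copy.deepcopy(model).cpu().eval()
    
    # Post-training dynamic quantization: nn.Linear weights are stored as int8 and
    # activations are quantized on the fly. Dynamic quantization does not support
    # nn.Conv2d, so the convolutional layers remain float32.
    quantized_model = torch.quantization.quantize_dynamic(
        model, {nn.Linear}, dtype=torch.qint8
    )
    
    return quantized_model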
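The arithmetic behind int8 weight storage can be sketched with symmetric per-tensor quantization. One floating-point scale, max|w| / 127, is shared by the whole tensor, and each weight becomes the integer round(w / scale). This is a simplified plain-Python illustration with hypothetical helper names. PyTorch's observers and rounding details may differ.

```python
def quantize_symmetric(weights, num_bits=8):
    """Map floats to signed integers with one shared scale (symmetric, per-tensor)."""
    qmax = 2 ** (num_bits - 1) - 1            # 127 for int8
    max_abs = max(abs(w) for w in weights)
    scale = max_abs / qmax if max_abs > 0 else 1.0  # avoid dividing by zero
    q = [max(-qmax - 1, min(qmax, round(w / scale))) for w in weights]
    return q, scale

def dequantize(q, scale):
    """Recover approximate float values from the integers and the scale."""
    return [v * scale for v in q]

weights = [0.5, -1.27, 0.003, 1.0]
q, scale = quantize_symmetric(weights)
print(q)                     # [50, -127, 0, 100]
print(dequantize(q, scale))  # close to the original weights, with rounding error
```

Each dequantized value is within half a scale step of the original, which is why small weights such as 0.003 can round to zero.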

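# Pruning: zero out low-magnitude weights
import torch.nn.utils.prune as prune

def prune_model(model, pruning_ratio=0.2):
    for module in model.modules():
        if isinstance(module, (nn.Conv2d, nn.Linear)):
            # Zero the pruning_ratio fraction of weights with the smallest |w| in this layer
            prune.l1_unstructured(module, name='weight', amount=pruning_ratio)
            # Make the pruning permanent (removes the mask and weight_orig reparametrization)
            prune.remove(module, 'weight')
    
    # Tensors keep their dense shape; the zeros save space only with sparse storage or compression
    return model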
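l1_unstructured zeroes, in each tensor, the given fraction of weights with the smallest absolute values. The plain-Python sketch below mimics that selection on a flat list. It assumes that, for a float amount, the number of pruned entries is round(amount × number of weights). The function name is hypothetical. Note that the list keeps its length: unstructured pruning creates zeros rather than removing entries.

```python
def l1_unstructured_prune(weights, amount=0.2):
    """Zero out the `amount` fraction of weights with the smallest |value|."""
    n_prune = round(amount * len(weights))
    # Indices of the n_prune smallest-magnitude weights
    to_prune = set(sorted(range(len(weights)), key=lambda i: abs(weights[i]))[:n_prune])
    return [0.0 if i in to_prune else w for i, w in enumerate(weights)]

pruned = l1_unstructured_prune([0.1, -0.5, 0.05, 2.0, -0.01], amount=0.4)
print(pruned)  # [0.1, -0.5, 0.0, 2.0, 0.0]
```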

Common Issues and Troubleshooting

Address frequent problems encountered when implementing AlexNet:

Memory Issues:

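  • Reduce batch size if you hit CUDA out-of-memory errors
  • Use gradient checkpointing, which recomputes activations during the backward pass instead of storing them
  • Split large batches into smaller chunks and accumulate gradients
  • Call torch.cuda.empty_cache() sparingly: it returns unused cached memory to the driver but does not free memory held by live tensors

# Memory-efficient training: gradient checkpointing plus micro-batch accumulation
from torch.utils.checkpoint import checkpoint_sequential

def forward_with_checkpointing(model, x, segments=2):
    # Recompute the activations of model.features during backward instead of storing them
    x = checkpoint_sequential(model.features, segments, x, use_reentrant=False)
    x = model.avgpool(x)
    x = torch.flatten(x, 1)
    return model.classifier(x)

def train_with_gradient_checkpointing(model, trainloader, criterion, optimizer, chunk_size=32):
    model.train()
    
    for step, (inputs, targets) in enumerate(trainloader):
        inputs, targets = inputs.to(device), targets.to(device)
        batch_size = inputs.size(0)
        
        optimizer.zero_grad()
        total_loss = 0.0
        
        # Process the batch in smaller chunks, accumulating gradients
        for i in range(0, batch_size, chunk_size):
            chunk_inputs = inputs[i:i+chunk_size]
            chunk_targets = targets[i:i+chunk_size]
            
            outputs = forward_with_checkpointing(model, chunk_inputs)
            # Weight by the chunk's share of the batch so the accumulated
            # gradient matches that of the full-batch mean loss
            loss = criterion(outputs, chunk_targets) * chunk_inputs.size(0) / batch_size
            loss.backward()
            total_loss += loss.item()
        
        optimizer.step()
        
        # Release cached allocator blocks occasionally; doing it every step slows training
        if torch.cuda.is_available() and step % 100 == 0:
            torch.cuda.empty_cache()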
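When a batch is split into chunks and backward() is called on each one, gradients accumulate. Because gradients add up linearly, the chunk losses must be weighted so that their sum equals the full-batch mean loss. Weighting each chunk's mean loss by chunk_size / batch_size achieves this for any batch size. A fixed divisor of batch_size // chunk_size is only exact when the batch divides evenly, and it is zero when the batch is smaller than one chunk. The plain-Python check below uses made-up per-sample losses and illustrative helper names.

```python
def full_batch_mean(losses):
    return sum(losses) / len(losses)

def chunked_mean(losses, chunk_size):
    """Sum of per-chunk mean losses, each weighted by its share of the batch."""
    total = 0.0
    for i in range(0, len(losses), chunk_size):
        chunk = losses[i:i + chunk_size]
        chunk_mean = sum(chunk) / len(chunk)
        total += chunk_mean * len(chunk) / len(losses)
    return total

losses = [float(i % 7) for i in range(80)]  # 80 samples: chunks of 32, 32 and 16
print(full_batch_mean(losses), chunked_mean(losses, 32))
```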

Training Stability:

# Gradient clipping for stable training
def train_with_gradient_clipping(model, trainloader, criterion, optimizer, max_norm=1.0):
    model.train()
    
    for inputs, targets in trainloader:
        inputs, targets = inputs.to(device), targets.to(device)
        
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        
        # Clip gradients
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
        
        optimizer.step()
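clip_grad_norm_ treats all parameter gradients as one long vector. It computes their total L2 norm and, if that exceeds max_norm, rescales every gradient by max_norm / total_norm, with a small epsilon added to the denominator. Gradients already below the threshold are left unchanged, and the direction of the update is preserved either way. The sketch below applies the rule to a flat list of gradient values in plain Python. The function name is hypothetical, and like clip_grad_norm_ it also returns the total norm.

```python
import math

def clip_by_total_norm(grads, max_norm, eps=1e-6):
    """Rescale a flat list of gradients so their combined L2 norm is at most max_norm."""
    total_norm = math.sqrt(sum(g * g for g in grads))
    coef = min(1.0, max_norm / (total_norm + eps))
    return [g * coef for g in grads], total_norm

clipped, norm = clip_by_total_norm([3.0, 4.0], max_norm=1.0)
print(norm, clipped)  # norm is 5.0; clipped is approximately [0.6, 0.8]
```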

# Learning rate scheduling
def cosine_annealing_scheduler(optimizer, T_max, eta_min=0):
    return optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=T_max, eta_min=eta_min)

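# Warmup learning rate
class WarmupScheduler:
    """Linear warmup: epoch e (0-based) trains at base_lr * (e + 1) / warmup_epochs.

    Call step() once at the end of each epoch, like PyTorch's built-in schedulers.
    After warmup the learning rate is left untouched (at base_lr unless another
    scheduler changes it).
    """
    def __init__(self, optimizer, warmup_epochs, base_lr):
        self.optimizer = optimizer
        self.warmup_epochs = warmup_epochs
        self.base_lr = base_lr
        self.current_epoch = 0
        self._apply_warmup_lr()  # epoch 0 must already use the first warmup value
    
    def _apply_warmup_lr(self):
        if self.current_epoch < self.warmup_epochs:
            lr = self.base_lr * (self.current_epoch + 1) / self.warmup_epochs
            for param_group in self.optimizer.param_groups:
                param_group['lr'] = lr
    
    def step(self):
        self.current_epoch += 1
        self._apply_warmup_lr()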
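Warmup and cosine annealing are often combined. The learning rate first ramps up linearly, then follows CosineAnnealingLR's curve, eta_min + (base_lr − eta_min) · (1 + cos(π · t / T_max)) / 2. The function below is a plain-Python sketch of such a combined schedule, useful for plotting or sanity-checking the values a run should see. The function name is hypothetical, and so is the convention that t counts epochs since warmup ended, with T_max covering the remaining epochs. In PyTorch, a similar schedule can be assembled from the built-in LinearLR, CosineAnnealingLR and SequentialLR schedulers.

```python
import math

def warmup_cosine_lr(epoch, base_lr, warmup_epochs, total_epochs, eta_min=0.0):
    """LR for a 0-based epoch: linear warmup, then cosine decay to eta_min."""
    if epoch < warmup_epochs:
        # Same ramp as WarmupScheduler: base_lr * (epoch + 1) / warmup_epochs
        return base_lr * (epoch + 1) / warmup_epochs
    t = epoch - warmup_epochs                  # epochs since warmup finished
    t_max = total_epochs - warmup_epochs       # length of the cosine phase
    return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * t / t_max)) / 2

for epoch in (0, 4, 5, 52, 99):
    print(epoch, round(warmup_cosine_lr(epoch, 0.01, warmup_epochs=5, total_epochs=100), 6))
```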

Best Practices and Production Deployment

When deploying AlexNet in production, consider these optimization strategies:

# Model serving with TorchScript
def export_for_production(model, sample_input):
    model.eval()
    
    # Convert to TorchScript
    traced_model = torch.jit.trace(model, sample_input)
    
    # Optimize for inference
    optimized_model = torch.jit.optimize_for_inference(traced_model)
    
    # Save the model
    torch.jit.save(optimized_model, 'alexnet_production.pt')
    
    return optimized_model

# Load and use in production
def load_production_model(model_path):
    model = torch.jit.load(model_path)
    model.eval()
    return model

# Batch inference for high throughput
def batch_inference(model, image_batch):
    model.eval()
    
    with torch.no_grad():
        outputs = model(image_batch.to(device))
        probabilities = torch.nn.functional.softmax(outputs, dim=1)
    
    return probabilities.cpu().numpy()

For high-throughput deployments, consider these server-side measures:

  • Use ONNX Runtime for cross-platform deployment
  • Implement model caching and connection pooling
  • Set up horizontal scaling with load balancers
  • Monitor GPU utilization and memory usage
  • Configure automatic model versioning and rollback
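The last point, versioning with rollback, can start as an ordered record of deployed model files plus a note of which one is active. The class below is a hypothetical in-memory sketch; ModelRegistry and its methods are illustrative names, not part of PyTorch or any serving framework. A real deployment would persist this state and load the active path with something like torch.jit.load.

```python
class ModelRegistry:
    """Hypothetical in-memory registry of model versions with rollback."""

    def __init__(self):
        self.paths = {}      # version -> model file path
        self.history = []    # previously active versions, most recent last
        self.active = None

    def register(self, version, path):
        self.paths[version] = path

    def activate(self, version):
        if version not in self.paths:
            raise KeyError(f"unknown version {version!r}")
        if self.active is not None:
            self.history.append(self.active)
        self.active = version

    def rollback(self):
        """Return to the previously active version."""
        if not self.history:
            raise RuntimeError("no earlier version to roll back to")
        self.active = self.history.pop()
        return self.active

    def active_path(self):
        return self.paths[self.active]


registry = ModelRegistry()
registry.register("v1", "alexnet_v1.pt")
registry.register("v2", "alexnet_v2.pt")
registry.activate("v1")
registry.activate("v2")
print(registry.active_path())  # alexnet_v2.pt
registry.rollback()
print(registry.active_path())  # alexnet_v1.pt
```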

The PyTorch documentation at https://pytorch.org/docs/stable/index.html provides comprehensive guides to model optimization and deployment. For GPU-accelerated training, cloud or VPS instances with CUDA-capable GPUs can be a cost-effective option for development and small-scale production workloads.

AlexNet remains valuable for educational purposes and as a starting point for custom CNN architectures. While modern networks offer superior performance, understanding AlexNet's principles provides essential insights into deep learning fundamentals and helps build more sophisticated computer vision systems.



