Writing ResNet from Scratch in PyTorch – Tutorial

ResNet (Residual Networks) transformed deep learning by making very deep neural networks practical to train. Before ResNet, simply stacking more layers often made results worse: deeper "plain" networks reached higher training error than shallower ones, a degradation problem that persisted even with careful initialization and batch normalization. By adding skip connections that carry the input, and its gradient, directly across groups of layers, ResNet enabled networks with a hundred or more layers that match or outperform their shallower counterparts. In this tutorial, you'll learn how to implement ResNet-18 and ResNet-50 from scratch in PyTorch, understand the mathematical foundations behind residual blocks, and explore practical training and deployment strategies.

Understanding ResNet Architecture and Skip Connections

The core innovation of ResNet lies in its residual blocks, whose skip connections bypass one or more layers. Instead of learning a desired mapping H(x) directly, the stacked layers in a residual block learn the residual function F(x) = H(x) - x, and the block outputs F(x) + x. If an identity mapping is close to optimal, the layers only need to drive F(x) toward zero, which is easier than fitting an identity through a stack of nonlinear layers.
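To make the identity argument concrete, here is a toy sketch. The ToyResidual class is hypothetical: it uses fully connected layers instead of convolutions and omits the ReLU that ResNet applies after the addition. When the last layer of the residual branch is zeroed, F(x) = 0, so the block returns its input exactly, and the gradient with respect to x passes through the skip path unchanged.

```python
import torch
import torch.nn as nn


class ToyResidual(nn.Module):
    """Hypothetical residual unit computing y = x + F(x), with F = Linear -> ReLU -> Linear."""

    def __init__(self, dim):
        super().__init__()
        self.residual = nn.Sequential(nn.Linear(dim, dim), nn.ReLU(), nn.Linear(dim, dim))

    def forward(self, x):
        return x + self.residual(x)  # skip connection adds the input back


torch.manual_seed(0)
block = ToyResidual(8)

# Zero the last layer of F: the branch now outputs 0, so the block is the identity.
nn.init.zeros_(block.residual[2].weight)
nn.init.zeros_(block.residual[2].bias)

x = torch.randn(4, 8, requires_grad=True)
y = block(x)
y.sum().backward()

print(torch.equal(y, x))                        # True: H(x) = F(x) + x = x
print(torch.equal(x.grad, torch.ones_like(x)))  # True: the gradient flows unchanged through the skip
```

The zero_init_residual option in the ResNet class below relies on the same idea: zeroing the last BatchNorm scale in each residual branch makes every block start out close to an identity mapping.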

ResNet architectures follow a specific pattern with four main stages of residual blocks, each operating at different spatial resolutions. The network begins with a 7×7 convolutional layer, followed by max pooling, then four stages of residual blocks with increasing channel dimensions and decreasing spatial dimensions.
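The spatial sizes follow from the standard output-size formula for convolution and pooling, floor((n + 2p - k) / s) + 1. The pure-Python sketch below (out_size is our own helper name) traces a 224×224 input through the stem and the four stages, assuming the stride placement used in the implementation below: stride 2 in the first block of stages 2 to 4.

```python
def out_size(n, kernel, stride, padding):
    """Output height/width of a conv or pooling layer (dilation 1, floor rounding)."""
    return (n + 2 * padding - kernel) // stride + 1


size = 224
trace = {}
size = out_size(size, kernel=7, stride=2, padding=3)  # 7x7 stem conv, stride 2
trace["conv1"] = size
size = out_size(size, kernel=3, stride=2, padding=1)  # 3x3 max pool, stride 2
trace["maxpool"] = size
for stage, stride in zip(["layer1", "layer2", "layer3", "layer4"], [1, 2, 2, 2]):
    # the strided 3x3 conv in the first block of each stage sets the size;
    # the remaining convs use stride 1 with "same" padding
    size = out_size(size, kernel=3, stride=stride, padding=1)
    trace[stage] = size

print(trace)
# {'conv1': 112, 'maxpool': 56, 'layer1': 56, 'layer2': 28, 'layer3': 14, 'layer4': 7}
```

The final 7×7 feature map is what AdaptiveAvgPool2d((1, 1)) averages down to a single vector before the fully connected layer.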

| ResNet Variant | Layers | Parameters (M) | ImageNet Top-1 Accuracy | FLOPs (G) |
|---|---|---|---|---|
| ResNet-18 | 18 | 11.7 | 69.8% | 1.8 |
| ResNet-34 | 34 | 21.8 | 73.3% | 3.7 |
| ResNet-50 | 50 | 25.6 | 76.1% | 4.1 |
| ResNet-101 | 101 | 44.5 | 77.4% | 7.8 |

Implementing Basic Building Blocks

Let’s start by implementing the fundamental components. ResNet uses two types of residual blocks: BasicBlock for ResNet-18/34 and Bottleneck for deeper variants like ResNet-50/101/152.

import torch
import torch.nn as nn
import torch.nn.functional as F

class BasicBlock(nn.Module):
    expansion = 1
    
    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        
        # First convolutional layer
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, 
                              stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        
        # Second convolutional layer
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3,
                              stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        
        self.downsample = downsample
        self.relu = nn.ReLU(inplace=True)
        
    def forward(self, x):
        identity = x
        
        # First conv block
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        
        # Second conv block
        out = self.conv2(out)
        out = self.bn2(out)
        
        # Apply downsample to identity if needed
        if self.downsample is not None:
            identity = self.downsample(x)
        
        # Add skip connection
        out += identity
        out = self.relu(out)
        
        return out

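The Bottleneck block uses three convolutional layers in a 1×1 → 3×3 → 1×1 pattern: the first 1×1 convolution reduces the channel count, the 3×3 convolution works at this reduced width, and the final 1×1 convolution expands the channels by the expansion factor of 4. This keeps computation low while maintaining representational power. Like torchvision's implementation (often called ResNet v1.5), the code below applies the stride in the 3×3 convolution; the original paper applied it in the first 1×1 convolution.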

class Bottleneck(nn.Module):
    expansion = 4
    
    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        
        # 1x1 conv for dimension reduction
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        
        # 3x3 conv for spatial processing
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3,
                              stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        
        # 1x1 conv for dimension expansion
        self.conv3 = nn.Conv2d(out_channels, out_channels * self.expansion,
                              kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_channels * self.expansion)
        
        self.downsample = downsample
        self.relu = nn.ReLU(inplace=True)
        
    def forward(self, x):
        identity = x
        
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        
        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)
        
        out = self.conv3(out)
        out = self.bn3(out)
        
        if self.downsample is not None:
            identity = self.downsample(x)
        
        out += identity
        out = self.relu(out)
        
        return out
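A quick weight count shows why the bottleneck design pays off. The pure-Python sketch below (conv_weights is our own helper name, and BatchNorm parameters are ignored) compares a 64-channel basic block, a hypothetical basic block at 256 channels, and the 256 → 64 → 64 → 256 bottleneck used in the first stage of ResNet-50:

```python
def conv_weights(c_in, c_out, k):
    """Number of weights in a k x k convolution without bias."""
    return c_in * c_out * k * k


# BasicBlock-style: two 3x3 convolutions
basic_64 = 2 * conv_weights(64, 64, 3)     # first stage of ResNet-18/34
basic_256 = 2 * conv_weights(256, 256, 3)  # hypothetical basic block at 256 channels

# Bottleneck: 1x1 reduce (256 -> 64), 3x3 at width 64, 1x1 expand (64 -> 256)
bottleneck_256 = (conv_weights(256, 64, 1)
                  + conv_weights(64, 64, 3)
                  + conv_weights(64, 256, 1))

print(basic_64, bottleneck_256, basic_256)  # 73728 69632 1179648
```

At a fixed spatial resolution, multiply-adds scale with these weight counts, so a bottleneck block that processes 256-channel features costs about as much as a 64-channel basic block and roughly 17× less than two 3×3 convolutions at 256 channels.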

Complete ResNet Implementation

Now we’ll implement the full ResNet architecture with configurable depth and block types:

class ResNet(nn.Module):
    def __init__(self, block, layers, num_classes=1000, zero_init_residual=False):
        super(ResNet, self).__init__()
        
        self.in_channels = 64
        
        # Initial convolutional layer
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        
        # Residual layers
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        
        # Classification head
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)
        
        # Initialize weights
        self._initialize_weights(zero_init_residual)
        
    def _make_layer(self, block, out_channels, blocks, stride=1):
        downsample = None
        
        if stride != 1 or self.in_channels != out_channels * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.in_channels, out_channels * block.expansion,
                         kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels * block.expansion),
            )
        
        layers = []
        layers.append(block(self.in_channels, out_channels, stride, downsample))
        self.in_channels = out_channels * block.expansion
        
        for _ in range(1, blocks):
            layers.append(block(self.in_channels, out_channels))
        
        return nn.Sequential(*layers)
    
    def _initialize_weights(self, zero_init_residual):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
        
        # Zero-initialize the last BN in each residual branch
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)
    
    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        
        return x

Create factory functions for different ResNet variants:

def resnet18(num_classes=1000, **kwargs):
    return ResNet(BasicBlock, [2, 2, 2, 2], num_classes=num_classes, **kwargs)

def resnet34(num_classes=1000, **kwargs):
    return ResNet(BasicBlock, [3, 4, 6, 3], num_classes=num_classes, **kwargs)

def resnet50(num_classes=1000, **kwargs):
    return ResNet(Bottleneck, [3, 4, 6, 3], num_classes=num_classes, **kwargs)

def resnet101(num_classes=1000, **kwargs):
    return ResNet(Bottleneck, [3, 4, 23, 3], num_classes=num_classes, **kwargs)

def resnet152(num_classes=1000, **kwargs):
    return ResNet(Bottleneck, [3, 8, 36, 3], num_classes=num_classes, **kwargs)
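As a sanity check on the factory functions, the parameter counts in the table above can be reproduced analytically. The sketch below uses our own helper names (conv, bn, count_resnet_params) and counts convolution weights (every conv here uses bias=False), BatchNorm scale and shift, the 1×1 projection shortcuts that _make_layer creates, and the final fc layer:

```python
def conv(c_in, c_out, k):
    return c_in * c_out * k * k  # bias=False in every conv of this ResNet


def bn(c):
    return 2 * c  # weight (gamma) and bias (beta); running stats are buffers, not parameters


def count_resnet_params(block, layers, num_classes=1000):
    """block: 'basic' or 'bottleneck'; layers: blocks per stage, e.g. [2, 2, 2, 2]."""
    expansion = 1 if block == "basic" else 4
    total = conv(3, 64, 7) + bn(64)  # stem; max pooling has no parameters
    c_in = 64
    for stage, (width, n_blocks) in enumerate(zip((64, 128, 256, 512), layers)):
        for i in range(n_blocks):
            stride = 2 if (stage > 0 and i == 0) else 1
            c_out = width * expansion
            if block == "basic":
                total += conv(c_in, width, 3) + bn(width) + conv(width, width, 3) + bn(width)
            else:
                total += (conv(c_in, width, 1) + bn(width)
                          + conv(width, width, 3) + bn(width)
                          + conv(width, c_out, 1) + bn(c_out))
            # 1x1 projection shortcut, same condition as in _make_layer
            if stride != 1 or c_in != c_out:
                total += conv(c_in, c_out, 1) + bn(c_out)
            c_in = c_out
    return total + c_in * num_classes + num_classes  # fc weight + bias


for name, block, layers in [("ResNet-18", "basic", [2, 2, 2, 2]),
                            ("ResNet-34", "basic", [3, 4, 6, 3]),
                            ("ResNet-50", "bottleneck", [3, 4, 6, 3]),
                            ("ResNet-101", "bottleneck", [3, 4, 23, 3])]:
    print(f"{name}: {count_resnet_params(block, layers):,}")
```

This prints 11,689,512, 21,797,672, 25,557,032 and 44,549,160, which round to the 11.7M, 21.8M, 25.6M and 44.5M in the table. Running sum(p.numel() for p in resnet50().parameters()) on the implementation above should report the same total for ResNet-50.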

Training Setup and Data Pipeline

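A fast, correct input pipeline matters as much as the model itself. The loaders below assume an ImageFolder layout (one subdirectory per class under train/ and val/), apply standard ImageNet-style augmentation for training, and use a resize-and-center-crop transform for validation: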

import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import torchvision.datasets as datasets

def get_data_loaders(data_path, batch_size=256, num_workers=4):
    # Data augmentation for training
    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                           std=[0.229, 0.224, 0.225])
    ])
    
    # Standard transforms for validation
    val_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                           std=[0.229, 0.224, 0.225])
    ])
    
    train_dataset = datasets.ImageFolder(
        root=f'{data_path}/train',
        transform=train_transform
    )
    
    val_dataset = datasets.ImageFolder(
        root=f'{data_path}/val',
        transform=val_transform
    )
    
    train_loader = DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True,
        num_workers=num_workers, pin_memory=True
    )
    
    val_loader = DataLoader(
        val_dataset, batch_size=batch_size, shuffle=False,
        num_workers=num_workers, pin_memory=True
    )
    
    return train_loader, val_loader

def train_epoch(model, train_loader, criterion, optimizer, device):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()
        _, predicted = output.max(1)
        total += target.size(0)
        correct += predicted.eq(target).sum().item()
        
        if batch_idx % 100 == 0:
            print(f'Batch {batch_idx}, Loss: {loss.item():.4f}, '
                  f'Acc: {100.*correct/total:.2f}%')
    
    epoch_loss = running_loss / len(train_loader)
    epoch_acc = 100. * correct / total
    
    return epoch_loss, epoch_acc

Advanced Training Techniques and Optimization

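Modern ResNet training benefits from several additional techniques. The trainer below follows the classic recipe (SGD with momentum 0.9, weight decay 1e-4, and an initial learning rate of 0.1 divided by 10 at fixed milestones) and adds label smoothing, mixed precision training, and gradient clipping. It assumes a CUDA device. On recent PyTorch releases, torch.cuda.amp.GradScaler and torch.cuda.amp.autocast emit deprecation warnings; torch.amp.GradScaler("cuda") and torch.autocast("cuda") are the current equivalents.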

from torch.cuda.amp import GradScaler, autocast
import torch.optim.lr_scheduler as lr_scheduler

class ResNetTrainer:
    def __init__(self, model, train_loader, val_loader, device='cuda'):
        self.model = model.to(device)
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.device = device
        
        # Loss and optimizer
        self.criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
        self.optimizer = optim.SGD(
            model.parameters(), 
            lr=0.1, 
            momentum=0.9, 
            weight_decay=1e-4
        )
        
        # Learning rate scheduler
        self.scheduler = lr_scheduler.MultiStepLR(
            self.optimizer, 
            milestones=[30, 60, 80], 
            gamma=0.1
        )
        
        # Mixed precision training
        self.scaler = GradScaler()
        
        # Metrics tracking
        self.train_losses = []
        self.train_accs = []
        self.val_losses = []
        self.val_accs = []
    
    def train_epoch(self):
        self.model.train()
        running_loss = 0.0
        correct = 0
        total = 0
        
        for data, target in self.train_loader:
            data, target = data.to(self.device), target.to(self.device)
            
            self.optimizer.zero_grad()
            
            with autocast():
                output = self.model(data)
                loss = self.criterion(output, target)
            
            self.scaler.scale(loss).backward()
            
            # Gradient clipping
            self.scaler.unscale_(self.optimizer)
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
            
            self.scaler.step(self.optimizer)
            self.scaler.update()
            
            running_loss += loss.item()
            _, predicted = output.max(1)
            total += target.size(0)
            correct += predicted.eq(target).sum().item()
        
        epoch_loss = running_loss / len(self.train_loader)
        epoch_acc = 100. * correct / total
        
        self.train_losses.append(epoch_loss)
        self.train_accs.append(epoch_acc)
        
        return epoch_loss, epoch_acc
    
    def validate(self):
        self.model.eval()
        running_loss = 0.0
        correct = 0
        total = 0
        
        with torch.no_grad():
            for data, target in self.val_loader:
                data, target = data.to(self.device), target.to(self.device)
                
                with autocast():
                    output = self.model(data)
                    loss = self.criterion(output, target)
                
                running_loss += loss.item()
                _, predicted = output.max(1)
                total += target.size(0)
                correct += predicted.eq(target).sum().item()
        
        epoch_loss = running_loss / len(self.val_loader)
        epoch_acc = 100. * correct / total
        
        self.val_losses.append(epoch_loss)
        self.val_accs.append(epoch_acc)
        
        return epoch_loss, epoch_acc
    
    def train(self, epochs=90):
        best_acc = 0.0
        
        for epoch in range(epochs):
            print(f'Epoch {epoch+1}/{epochs}')
            
            train_loss, train_acc = self.train_epoch()
            val_loss, val_acc = self.validate()
            self.scheduler.step()
            
            print(f'Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%')
            print(f'Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%')
            print(f'LR: {self.optimizer.param_groups[0]["lr"]:.6f}')
            
            # Save best model
            if val_acc > best_acc:
                best_acc = val_acc
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': self.model.state_dict(),
                    'optimizer_state_dict': self.optimizer.state_dict(),
                    'best_acc': best_acc,
                }, 'best_resnet_model.pth')
            
            print('-' * 50)
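The Common Pitfalls section recommends a warmup phase, while ResNetTrainer as written uses plain MultiStepLR. One way to add linear warmup is LambdaLR with a multiplier function. This is a sketch under stated assumptions: warmup_multistep is a hypothetical helper, the 5-epoch warmup length is our choice, and nn.Linear stands in for the ResNet.

```python
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import LambdaLR


def warmup_multistep(warmup_epochs=5, milestones=(30, 60, 80), gamma=0.1):
    """Hypothetical helper: LR multiplier with linear warmup, then step decay at each milestone."""
    def factor(epoch):
        if epoch < warmup_epochs:
            return (epoch + 1) / warmup_epochs  # 1/5, 2/5, ..., 1.0
        return gamma ** sum(epoch >= m for m in milestones)
    return factor


model = nn.Linear(10, 2)  # stand-in for the ResNet
optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4)
scheduler = LambdaLR(optimizer, lr_lambda=warmup_multistep())

lrs = []
for epoch in range(90):
    lrs.append(optimizer.param_groups[0]["lr"])  # LR used during this epoch
    # train_epoch() and validate() would run here; the optimizer steps inside training
    optimizer.step()  # no gradients here, so this is a no-op that keeps the step order valid
    scheduler.step()

print(lrs[0], lrs[4], lrs[30], lrs[60], lrs[89])
```

LambdaLR multiplies the base learning rate by the returned factor, so the LR starts at 0.02, reaches 0.1 after the warmup, and then drops to 0.01, 0.001 and 0.0001 at epochs 30, 60 and 80. To use it in ResNetTrainer, the MultiStepLR line could be replaced with LambdaLR(self.optimizer, lr_lambda=warmup_multistep()).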

Real-World Deployment and Optimization

For production deployment, ResNet models require optimization for inference speed and memory usage. Here are practical techniques for model deployment:

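# Dynamic quantization for faster CPU inference
def quantize_model(model):
    model.eval()
    # quantize_dynamic converts nn.Linear (and recurrent) layers to int8 but does not
    # convert nn.Conv2d, so in a ResNet only the final fc layer changes. Quantizing the
    # convolutions requires post-training static quantization or quantization-aware training.
    # Dynamically quantized models run on the CPU.
    quantized_model = torch.quantization.quantize_dynamic(
        model.cpu(), {nn.Linear}, dtype=torch.qint8
    )
    return quantized_model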

# TorchScript compilation for production
def compile_model(model, example_input):
    model.eval()
    traced_model = torch.jit.trace(model, example_input)
    return traced_model

# Model inference with preprocessing
class ResNetInference:
    def __init__(self, model_path, device='cuda'):
        self.device = device
        self.model = self.load_model(model_path)
        self.transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                               std=[0.229, 0.224, 0.225])
        ])
    
    def load_model(self, model_path):
        model = resnet50()  # Or any other variant
        checkpoint = torch.load(model_path, map_location=self.device)
        model.load_state_dict(checkpoint['model_state_dict'])
        model.to(self.device)
        model.eval()
        return model
    
    def predict(self, image):
        with torch.no_grad():
            input_tensor = self.transform(image).unsqueeze(0).to(self.device)
            output = self.model(input_tensor)
            probabilities = F.softmax(output, dim=1)
            predicted_class = output.argmax(dim=1)
            confidence = probabilities.max().item()
        
        return predicted_class.item(), confidence

# Batch inference for efficiency
from PIL import Image

def batch_inference(model, image_paths, transform, batch_size=32, device='cuda'):
    # transform: the same preprocessing used for validation (Resize, CenterCrop, ToTensor, Normalize)
    model.eval()
    model.to(device)
    results = []
    
    for i in range(0, len(image_paths), batch_size):
        batch_paths = image_paths[i:i+batch_size]
        batch_tensors = []
        
        for path in batch_paths:
            image = Image.open(path).convert('RGB')
            tensor = transform(image)
            batch_tensors.append(tensor)
        
        batch_input = torch.stack(batch_tensors).to(device)
        
        with torch.no_grad():
            outputs = model(batch_input)
            probabilities = F.softmax(outputs, dim=1)
            confidences, predictions = probabilities.max(dim=1)
        
        for path, pred, conf in zip(batch_paths, predictions.tolist(), confidences.tolist()):
            results.append({
                'path': path,
                'prediction': pred,
                'confidence': conf
            })
    
    return results
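The traced module returned by compile_model is usually saved and then loaded in the serving process. A minimal round-trip sketch, assuming a small nn.Sequential stand-in for the ResNet and an in-memory buffer in place of a file path such as "resnet_traced.pt":

```python
import io

import torch
import torch.nn as nn

# Small stand-in for the ResNet; any eval-mode nn.Module traces the same way
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1, bias=False),
    nn.BatchNorm2d(8),
    nn.ReLU(inplace=True),
    nn.AdaptiveAvgPool2d((1, 1)),
    nn.Flatten(),
    nn.Linear(8, 10),
).eval()  # eval() so BatchNorm uses running statistics in the traced graph

example_input = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    traced = torch.jit.trace(model, example_input)

# Serialize and reload; the loader does not need the Python class definitions
buffer = io.BytesIO()
torch.jit.save(traced, buffer)
buffer.seek(0)
loaded = torch.jit.load(buffer)

batch = torch.randn(4, 3, 224, 224)  # a different batch size than the example input
with torch.no_grad():
    ref = model(batch)
    out = loaded(batch)

print(out.shape, torch.allclose(ref, out, atol=1e-6))
```

Tracing records only the operations executed for the example input, so this works for ResNet, which has no data-dependent control flow in its forward pass.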

Performance Benchmarking and Common Issues

Understanding ResNet's performance characteristics is essential for production deployment. The benchmarking function below measures average single-image inference latency, memory usage, and parameter count:

import time
import psutil
import numpy as np

def benchmark_model(model, input_size=(3, 224, 224), device='cuda', num_runs=100):
    model.eval()
    model.to(device)
    use_cuda = str(device).startswith('cuda')
    
    dummy_input = torch.randn(1, *input_size, device=device)
    if use_cuda:
        # Reset the peak counter so the reading reflects this model only
        torch.cuda.reset_peak_memory_stats()
    
    with torch.no_grad():
        # Warmup (lets cuDNN pick algorithms and the caching allocator settle)
        for _ in range(10):
            _ = model(dummy_input)
        
        # CUDA kernels run asynchronously, so synchronize before reading the clock
        if use_cuda:
            torch.cuda.synchronize()
        start_time = time.perf_counter()
        
        for _ in range(num_runs):
            _ = model(dummy_input)
        
        if use_cuda:
            torch.cuda.synchronize()
        end_time = time.perf_counter()
    
    avg_inference_time = (end_time - start_time) / num_runs * 1000  # ms
    
    # Memory usage
    if use_cuda:
        memory_usage = torch.cuda.max_memory_allocated() / 1024**2  # peak GPU memory, MB
    else:
        memory_usage = psutil.Process().memory_info().rss / 1024**2  # whole-process RSS, MB
    
    # Model statistics
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    
    results = {
        'avg_inference_time_ms': avg_inference_time,
        'memory_usage_mb': memory_usage,
        'total_parameters': total_params,
        'trainable_parameters': trainable_params,
        'model_size_mb': total_params * 4 / 1024**2  # Assuming float32
    }
    
    return results

# Performance comparison
models = {
    'ResNet-18': resnet18(),
    'ResNet-50': resnet50(),
    'ResNet-101': resnet101()
}

print("Model Performance Comparison:")
print("-" * 80)
print(f"{'Model':<12} {'Params(M)':<10} {'Inference(ms)':<15} {'Memory(MB)':<12} {'Size(MB)':<10}")
print("-" * 80)

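for name, model in models.items():
    results = benchmark_model(model)
    print(f"{name:<12} {results['total_parameters']/1e6:<10.1f} "
          f"{results['avg_inference_time_ms']:<15.2f} "
          f"{results['memory_usage_mb']:<12.1f} "
          f"{results['model_size_mb']:<10.1f}")
    # Move the model off the GPU so it does not inflate the next model's memory reading
    model.to('cpu')
    if torch.cuda.is_available():
        torch.cuda.empty_cache()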
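As an alternative to a hand-rolled timing loop, PyTorch provides torch.utils.benchmark.Timer, which handles warmup and CUDA synchronization and reports statistics over many runs. A CPU sketch with a small stand-in model (swap in resnet50() and a CUDA input for real measurements); note that Timer uses a single thread unless num_threads is set:

```python
import torch
import torch.nn as nn
from torch.utils import benchmark

# Small stand-in model so the sketch runs quickly on CPU
model = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
).eval()
x = torch.randn(1, 3, 224, 224)


@torch.no_grad()
def infer(inp):
    return model(inp)


timer = benchmark.Timer(
    stmt="infer(x)",
    globals={"infer": infer, "x": x},
    num_threads=torch.get_num_threads(),  # Timer defaults to 1 thread
    label="stand-in model",
)
measurement = timer.blocked_autorange(min_run_time=0.2)
print(f"median: {measurement.median * 1e3:.3f} ms, mean: {measurement.mean * 1e3:.3f} ms")
```

Reporting the median rather than the mean makes the result less sensitive to occasional slow runs caused by other processes.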

Common Pitfalls and Troubleshooting

Training ResNet from scratch presents several challenges. Here are the most common issues and their solutions:

  • Vanishing Gradients: Even with skip connections, very deep ResNets can suffer from gradient issues. Use proper weight initialization and consider gradient clipping.
  • Memory Issues: Large batch sizes can cause out-of-memory (OOM) errors. Reduce the batch size and scale the learning rate down proportionally, or keep the effective batch size with gradient accumulation.
  • Training Instability: A learning rate that is too high can make the loss explode. Lower the initial learning rate or add a warmup schedule.
  • Poor Convergence: Incorrect data normalization or augmentation can hurt performance. Verify that the normalization statistics match those the model is trained (or was pretrained) with, and that training and inference preprocessing agree.
  • Overfitting: Deep networks overfit easily on small datasets. Use data augmentation, weight decay and other regularization, or dropout before the final classifier.
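The memory-related remedies can be checked numerically. In this sketch, nn.Linear stands in for the ResNet and linear_scaled_lr is a hypothetical helper implementing the linear scaling heuristic against the lr=0.1 at batch size 256 baseline used in this tutorial. Four micro-batches of 8, each loss divided by 4, produce the same gradient as one batch of 32, because CrossEntropyLoss averages over the batch:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(16, 4)           # stand-in model without BatchNorm (see note below)
criterion = nn.CrossEntropyLoss()  # mean reduction over the batch
data = torch.randn(32, 16)
target = torch.randint(0, 4, (32,))


def linear_scaled_lr(batch_size, base_lr=0.1, base_batch=256):
    """Hypothetical helper: scale the LR in proportion to the effective batch size."""
    return base_lr * batch_size / base_batch


# Reference: gradient of one full batch of 32
model.zero_grad()
criterion(model(data), target).backward()
full_grad = model.weight.grad.clone()

# Gradient accumulation: 4 micro-batches of 8, each loss divided by the number of steps
accum_steps = 4
model.zero_grad()
for micro_data, micro_target in zip(data.chunk(accum_steps), target.chunk(accum_steps)):
    loss = criterion(model(micro_data), micro_target) / accum_steps
    loss.backward()  # gradients add up in .grad across backward() calls
accum_grad = model.weight.grad.clone()
# optimizer.step() and optimizer.zero_grad() would run here, once per accum_steps micro-batches

print(torch.allclose(full_grad, accum_grad, atol=1e-6))  # True
print(linear_scaled_lr(64))                             # LR for a batch size of 64
```

With BatchNorm, as in ResNet, each micro-batch is normalized with its own statistics, so accumulation closely approximates a larger batch rather than reproducing it exactly.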

# Debugging utilities
import matplotlib.pyplot as plt
import numpy as np

def check_gradient_flow(model):
    """Plot mean and max absolute gradient per parameter; call after loss.backward()"""
    ave_grads = []
    max_grads = []
    layers = []
    
    for n, p in model.named_parameters():
        if p.requires_grad and p.grad is not None:
            layers.append(n)
            ave_grads.append(p.grad.abs().mean().item())
            max_grads.append(p.grad.abs().max().item())
    
    plt.figure(figsize=(12, 6))
    plt.bar(np.arange(len(max_grads)), max_grads, alpha=0.7, label="max-gradient")
    plt.bar(np.arange(len(ave_grads)), ave_grads, alpha=0.7, label="mean-gradient")
    plt.hlines(0, 0, len(ave_grads)+1, lw=2, color="k")
    plt.xticks(range(0, len(ave_grads), 1), layers, rotation="vertical")
    plt.xlim(left=0, right=len(ave_grads))
    plt.ylim(bottom=-0.001, top=0.02)
    plt.xlabel("Layers")
    plt.ylabel("Gradient")
    plt.title("Gradient Flow")
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.show()

def validate_data_loading(train_loader):
    """Validate data loader configuration"""
    batch = next(iter(train_loader))
    images, labels = batch
    
    print(f"Batch shape: {images.shape}")
    print(f"Labels shape: {labels.shape}")
    print(f"Image range: [{images.min():.3f}, {images.max():.3f}]")
    print(f"Labels range: [{labels.min()}, {labels.max()}]")
    print(f"Data type: {images.dtype}")
    
    # Check for NaN or infinite values
    if torch.isnan(images).any():
        print("WARNING: NaN values found in images!")
    if torch.isinf(images).any():
        print("WARNING: Infinite values found in images!")
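validate_data_loading prints the image value range. After ToTensor (values in [0, 1]) and Normalize with the ImageNet statistics used in this tutorial, each channel c maps to (x - mean_c) / std_c, so a correctly configured loader should stay roughly within [-2.12, 2.64]. A short pure-Python check of those bounds:

```python
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

# Per-channel bounds of (x - mean) / std for x in [0, 1]
bounds = [((0.0 - m) / s, (1.0 - m) / s) for m, s in zip(mean, std)]
for channel, (lo, hi) in zip("RGB", bounds):
    print(f"{channel}: [{lo:.3f}, {hi:.3f}]")

overall = (min(lo for lo, _ in bounds), max(hi for _, hi in bounds))
print(f"overall: [{overall[0]:.3f}, {overall[1]:.3f}]")  # overall: [-2.118, 2.640]
```

If the printed range is [0, 1], Normalize is probably missing; values far outside these bounds suggest the images were normalized twice or loaded on the wrong scale.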

For large-scale training and deployment, dedicated servers with multiple GPUs and plenty of memory are the practical choice. For development and experimentation, GPU-equipped VPS instances are a cost-effective option for smaller-scale ResNet training and inference.

The complete ResNet implementation provides a solid foundation for computer vision projects. Key considerations include proper initialization, appropriate learning rate scheduling, and careful attention to batch normalization placement. For production deployment, consider model quantization, TorchScript compilation, and batch inference optimization to achieve optimal performance. The modular design allows easy extension to custom architectures and domain-specific modifications while maintaining the core residual learning principles that make ResNet so effective.

Additional resources for deeper understanding include the original ResNet paper at https://arxiv.org/abs/1512.03385 and the PyTorch documentation for advanced training techniques at https://pytorch.org/docs/stable/index.html.



This article incorporates information and material from various online sources. We acknowledge and appreciate the work of all original authors, publishers, and websites. While every effort has been made to appropriately credit the source material, any unintentional oversight or omission does not constitute a copyright infringement. All trademarks, logos, and images mentioned are the property of their respective owners. If you believe that any content used in this article infringes upon your copyright, please contact us immediately for review and prompt action.

This article is intended for informational and educational purposes only and does not infringe on the rights of the copyright owners. If any copyrighted material has been used without proper credit or in violation of copyright laws, it is unintentional and we will rectify it promptly upon notification. Please note that the republishing, redistribution, or reproduction of part or all of the contents in any form is prohibited without express written permission from the author and website owner. For permissions or further inquiries, please contact us.
