Transfer Learning with PyTorch: Build an Image Classifier in 30 Minutes

Learn how to use pretrained models like ResNet and EfficientNet to build custom image classifiers with PyTorch. Feature extraction, fine-tuning, and training best practices.

Training a deep neural network from scratch takes weeks and thousands of labeled images. Most developers don’t have that time or data. Transfer learning solves this by letting you reuse models already trained on millions of images, then adapt them to your specific task in hours.

This tutorial shows you how to take a pretrained model like ResNet50, swap its classification head, and train it to recognize your own image categories. You’ll understand when to freeze layers, when to fine-tune them, and how to avoid common mistakes that waste training time.

Prerequisites

You need Python 3.8 or higher, PyTorch 2.0+, and torchvision. Install them with:

pip install torch torchvision pillow matplotlib

You should know basic PyTorch tensor operations and have trained a simple neural network before. Familiarity with convolutional layers helps but isn’t required.

For this tutorial, we’ll use a small dataset (50-100 images per class works fine). You can use your own images or download a dataset from Kaggle.

Step 1: What Is Transfer Learning and Why It Works

Transfer learning takes a model trained on one task and applies it to a different but related task. A model trained to recognize 1,000 ImageNet categories has learned to detect edges, textures, shapes, and object parts. These features work for most image recognition tasks.

Think of it like learning a new language. If you already speak Spanish, learning Italian is faster because both share similar grammar and vocabulary. Your brain transfers existing knowledge to the new task. Neural networks do the same.

Two approaches exist: feature extraction and fine-tuning. Feature extraction freezes the pretrained layers and only trains the new classifier head. Fine-tuning updates all layers (or selected ones) with a small learning rate. We’ll cover both methods.

Transfer learning reduces training time by 10-100x compared to training from scratch. It also performs better when you have limited data because the pretrained features already encode visual patterns.

Step 2: Loading Pretrained Models (ResNet, VGG, EfficientNet)

PyTorch’s torchvision library includes dozens of pretrained models. Here’s how to load the most popular ones:

import torch
import torchvision.models as models

# Load ResNet50 with ImageNet weights
resnet50 = models.resnet50(weights='IMAGENET1K_V2')

# Load VGG16
vgg16 = models.vgg16(weights='IMAGENET1K_V1')

# Load EfficientNet-B0
efficientnet = models.efficientnet_b0(weights='IMAGENET1K_V1')

# Check the model architecture
print(resnet50)

Each model has a slightly different architecture. ResNet models use residual connections and are good all-around choices. VGG models are older and simpler but require more memory. EfficientNet models balance accuracy and speed.

The weights parameter specifies which pretrained weights to load. IMAGENET1K_V2 means weights trained on ImageNet with improved training recipes. Setting weights=None gives you a randomly initialized model (don’t do this for transfer learning).

You can also check available models:

# List all available models by name (torchvision >= 0.14)
available = models.list_models(module=models)
print(available[:10])

Different models have different input size requirements. ResNet expects 224x224 images by default. EfficientNet variants use different sizes (B0 uses 224x224, B7 uses 600x600). Always check the model documentation.

Step 3: Feature Extraction vs Fine-Tuning

Feature extraction freezes the pretrained layers and only trains the final classification layer. This is fast and works well when your dataset is similar to ImageNet (photos of objects, animals, etc.).

Fine-tuning updates some or all layers with a small learning rate. Use this when your dataset differs from ImageNet (medical images, satellite photos) or when you have enough data to improve the pretrained features.

Here’s how to set up feature extraction:

import torch.nn as nn

# Load pretrained ResNet50
model = models.resnet50(weights='IMAGENET1K_V2')

# Freeze all layers
for param in model.parameters():
    param.requires_grad = False

# Replace the final layer
# ResNet50's final layer is 'fc' (fully connected)
num_features = model.fc.in_features
num_classes = 5  # Your number of classes
model.fc = nn.Linear(num_features, num_classes)

# Only the new layer is trainable
print(f"Trainable params: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")

For fine-tuning, you can unfreeze all layers or just the last few:

# Load pretrained model
model = models.resnet50(weights='IMAGENET1K_V2')

# Freeze everything first
for param in model.parameters():
    param.requires_grad = False

# Replace final layer (newly created layers are trainable by default)
num_features = model.fc.in_features
num_classes = 5
model.fc = nn.Linear(num_features, num_classes)

# Unfreeze the last residual block (layer4) as well
for param in model.layer4.parameters():
    param.requires_grad = True

# Or unfreeze everything for full fine-tuning
for param in model.parameters():
    param.requires_grad = True

A good strategy: start with feature extraction for a few epochs, then fine-tune the last few layers. This prevents the random classifier weights from corrupting the pretrained features during early training.

Step 4: Building a Custom Image Classifier

Let’s build a complete classifier for a 5-class problem (cats, dogs, birds, fish, and reptiles). First, organize your images:

dataset/
    train/
        cats/
            cat001.jpg
            cat002.jpg
            ...
        dogs/
            dog001.jpg
            ...
    val/
        cats/
            cat501.jpg
            ...
        dogs/
            dog501.jpg
            ...

Now create the data loaders:

from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# Data preprocessing
data_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'val': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}

# Load datasets
data_dir = 'dataset'
image_datasets = {
    x: datasets.ImageFolder(f'{data_dir}/{x}', data_transforms[x])
    for x in ['train', 'val']
}

# Create data loaders
dataloaders = {
    x: DataLoader(image_datasets[x], batch_size=32,
                  shuffle=(x == 'train'), num_workers=4)
    for x in ['train', 'val']
}

dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
class_names = image_datasets['train'].classes
print(f"Classes: {class_names}")
print(f"Training images: {dataset_sizes['train']}")
print(f"Validation images: {dataset_sizes['val']}")

The normalization values [0.485, 0.456, 0.406] and [0.229, 0.224, 0.225] are ImageNet’s mean and standard deviation. Always use these when working with ImageNet pretrained models.

Data augmentation (RandomResizedCrop, RandomHorizontalFlip) helps prevent overfitting. Only apply augmentation to training data, not validation data.

Now set up the model:

import torch.optim as optim

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load model and freeze the pretrained layers (feature extraction)
model = models.resnet50(weights='IMAGENET1K_V2')
for param in model.parameters():
    param.requires_grad = False
num_features = model.fc.in_features
model.fc = nn.Linear(num_features, len(class_names))
model = model.to(device)

# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.fc.parameters(), lr=0.001, momentum=0.9)

# Learning rate scheduler
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

Using a learning rate scheduler reduces the learning rate during training. This helps the model converge to better solutions.
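You can verify the schedule without any model at all. With step_size=7 and gamma=0.1, StepLR multiplies the learning rate by 0.1 every 7 epochs:

```python
import torch
import torch.optim as optim

# A single dummy parameter stands in for the model
param = torch.nn.Parameter(torch.zeros(1))
optimizer = optim.SGD([param], lr=0.001, momentum=0.9)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

lrs = []
for epoch in range(21):
    lrs.append(optimizer.param_groups[0]['lr'])
    optimizer.step()
    scheduler.step()

# Epochs 0-6 run at 0.001, epochs 7-13 at 0.0001, epochs 14-20 at 0.00001
```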

Step 5: Training, Evaluation, and Saving Models

Here’s the training loop:

import time
import copy

def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
    since = time.time()
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in range(num_epochs):
        print(f'Epoch {epoch+1}/{num_epochs}')
        print('-' * 10)

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()
            else:
                model.eval()

            running_loss = 0.0
            running_corrects = 0

            # Iterate over data
            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)

                # Zero the parameter gradients
                optimizer.zero_grad()

                # Forward pass
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    # Backward pass + optimize only in training phase
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # Statistics
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)

            if phase == 'train':
                scheduler.step()

            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects.double() / dataset_sizes[phase]

            print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')

            # Deep copy the model if it's the best so far
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())

        print()

    time_elapsed = time.time() - since
    print(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
    print(f'Best val Acc: {best_acc:.4f}')

    # Load best model weights
    model.load_state_dict(best_model_wts)
    return model

# Train the model
model = train_model(model, criterion, optimizer, scheduler, num_epochs=25)

The training loop switches between training and evaluation modes. In training mode, dropout and batch normalization behave differently than in evaluation mode.
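A quick way to see the difference: dropout zeroes activations during training but becomes a no-op in eval mode.

```python
import torch
import torch.nn as nn

dropout = nn.Dropout(p=0.5)
x = torch.ones(1, 10)

dropout.train()          # what model.train() sets on every submodule
train_out = dropout(x)   # roughly half the values zeroed, survivors scaled to 2.0

dropout.eval()           # what model.eval() sets
eval_out = dropout(x)    # identity: dropout does nothing at inference time
```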

We save the model weights that achieve the best validation accuracy. This prevents overfitting to the training set.

To save and load your trained model:

# Save the model weights (state dict)
torch.save(model.state_dict(), 'resnet50_custom.pth')

# Load it later
model = models.resnet50(weights=None)
num_features = model.fc.in_features
model.fc = nn.Linear(num_features, len(class_names))
model.load_state_dict(torch.load('resnet50_custom.pth', map_location=device))
model = model.to(device)
model.eval()

To make predictions on new images:

from PIL import Image

def predict_image(image_path, model, transform):
    image = Image.open(image_path).convert('RGB')
    image_tensor = transform(image).unsqueeze(0).to(device)
    
    model.eval()
    with torch.no_grad():
        outputs = model(image_tensor)
        _, predicted = torch.max(outputs, 1)
        
    return class_names[predicted.item()]

# Test it
transform = data_transforms['val']
prediction = predict_image('test_image.jpg', model, transform)
print(f"Predicted class: {prediction}")
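If you also want a confidence score, apply softmax to the logits before taking the argmax. A variant of predict_image (the helper name is illustrative, and it assumes the image tensor is already preprocessed to a batch of shape (1, 3, 224, 224)):

```python
import torch
import torch.nn.functional as F

def predict_with_confidence(image_tensor, model, class_names):
    """Return (class_name, probability) for a preprocessed image batch."""
    model.eval()
    with torch.no_grad():
        outputs = model(image_tensor)
        probs = F.softmax(outputs, dim=1)       # logits -> probabilities
        confidence, predicted = torch.max(probs, 1)
    return class_names[predicted.item()], confidence.item()
```

Keep in mind that softmax probabilities from a fine-tuned model are often overconfident, so treat the score as a relative ranking rather than a calibrated probability.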

Common Pitfalls

Many beginners make these mistakes:

Not freezing layers properly. If you train all layers with a high learning rate, the pretrained features get corrupted. Start with feature extraction, then fine-tune carefully.

Wrong input normalization. ImageNet models expect inputs normalized with ImageNet statistics. Using different normalization values (or none at all) tanks performance. Always use the standard ImageNet mean and std.

Training too long. Transfer learning is fast. If your validation loss stops improving after 10 epochs, stop training. More epochs often lead to overfitting when working with small datasets.

Using the wrong learning rate. When fine-tuning, use a learning rate 10-100x smaller than when training from scratch. Start with 0.001 for feature extraction and 0.0001 for fine-tuning.

Ignoring class imbalance. If you have 500 cat images and 50 dog images, the model will bias toward cats. Use weighted sampling or class weights in your loss function to fix this.
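One way to add class weights is inverse-frequency weighting in the loss. The counts below are hypothetical placeholders for your own dataset:

```python
import torch
import torch.nn as nn

# Hypothetical per-class image counts, e.g. 500 cats vs 50 dogs
class_counts = torch.tensor([500.0, 50.0, 120.0, 80.0, 300.0])

# Inverse-frequency weights: rarer classes contribute more to the loss
weights = class_counts.sum() / (len(class_counts) * class_counts)
criterion = nn.CrossEntropyLoss(weight=weights)
```

The alternative is to leave the loss unweighted and oversample rare classes at the DataLoader level with torch.utils.data.WeightedRandomSampler.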

Not using data augmentation. Small datasets overfit without augmentation. Random crops, flips, and color jittering give the model more varied examples to learn from.

Summary

Transfer learning lets you build accurate image classifiers with limited data and computing resources. Load a pretrained model from torchvision, replace the final layer with your own classifier, and train for a few epochs.

Start with feature extraction (frozen layers) if your dataset resembles ImageNet. Use fine-tuning if your images look nothing like natural photos. Always use proper input normalization and data augmentation.

The complete workflow: prepare your data folders, create data loaders with the right transforms, load a pretrained model, replace the classifier head, train with a learning rate scheduler, and save the best model weights. You can build a working classifier in 30 minutes once you have the data organized.

Transfer learning makes deep learning accessible to small teams. You don’t need a GPU cluster or months of training time. You just need a few hundred labeled images and the right approach.
