Knowledge Distillation Tutorial: Building Small, Fast Models that Perform

Hands-on knowledge distillation tutorial for compact models: concepts, PyTorch/Keras code, tuning tips, and deployment with quantization.

Why knowledge distillation matters

Deploying deep learning on phones, browsers, wearables, and embedded devices demands models that are small, fast, and energy‑efficient—without giving up too much accuracy. Knowledge distillation (KD) is a proven strategy to transfer performance from a large “teacher” network to a compact “student” model by training the student to mimic the teacher’s behavior.

This tutorial walks through the core ideas, a practical PyTorch implementation, variants you can try, and a production‑minded workflow that includes quantization and deployment.

The core idea, in one paragraph

Standard training uses hard labels (e.g., class 7). KD augments this with soft targets from a teacher: the teacher’s probability distribution over classes reveals dark knowledge—relative class similarities (e.g., cat vs. fox). We compute a combined loss: hard‑label cross‑entropy plus a divergence between teacher and student soft predictions at temperature T. The temperature smooths distributions; a higher T reveals more structure.
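
To see what temperature does, here is a tiny numeric illustration (the logits are made up for this example): at T=1 the teacher looks almost one-hot, while at T=4 the relative similarity between classes becomes visible to the student.

import torch
import torch.nn.functional as F

# Hypothetical teacher logits for classes [cat, fox, car]
logits = torch.tensor([4.0, 2.5, -1.0])
print(F.softmax(logits / 1.0, dim=0))  # T=1: roughly [0.81, 0.18, 0.01], nearly one-hot
print(F.softmax(logits / 4.0, dim=0))  # T=4: roughly [0.51, 0.35, 0.15], cat/fox similarity is now visible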

When to use KD

  • Compressing a high‑accuracy model for on‑device inference.
  • Speeding up inference on CPU/NPUs when latency budgets are tight.
  • Improving small‑model accuracy when labeled data is limited or noisy.
  • Preserving accuracy after pruning or quantization.

Choosing teacher and student

  • Teacher: accurate, possibly overparameterized (e.g., a large ViT or ResNet) trained on the task/domain.
  • Student: architecture targeted at constraints—MobileNetV2/V3, ShuffleNet, EfficientNet‑Lite, small ResNets, or a custom CNN/Transformer with reduced width/depth.
  • Match output dimensions (same number of classes). For feature‑based KD, consider aligning intermediate channel sizes or use 1×1 adapters.

The standard KD loss

Let z_t be teacher logits, z_s student logits, y hard labels, T temperature, and α the distillation weight.

Loss = (1 − α) · CE(y, z_s) + α · T² · KL(softmax(z_t/T) || softmax(z_s/T))

Multiplying by T² preserves gradient scale. Typical values: T ∈ [2, 8], α ∈ [0.1, 0.9]. Start with T=4, α=0.5.

End‑to‑end PyTorch tutorial

Below is a minimal yet production‑oriented training script that distills a MobileNetV2 student from a ResNet‑50 teacher on CIFAR‑10. Adapt for your dataset.

# pip install torch torchvision
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, models

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# 1) Data
train_tfms = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
])
val_tfms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
])

train_ds = datasets.CIFAR10(root='./data', train=True, download=True, transform=train_tfms)
val_ds   = datasets.CIFAR10(root='./data', train=False, download=True, transform=val_tfms)

train_loader = DataLoader(train_ds, batch_size=128, shuffle=True, num_workers=4, pin_memory=True)
val_loader   = DataLoader(val_ds,   batch_size=256, shuffle=False, num_workers=4, pin_memory=True)

# 2) Teacher and student
teacher = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2)
teacher.fc = nn.Linear(2048, 10)  # new 10-class head (randomly initialized)
# Load a task-specific checkpoint (or fine-tune the teacher on CIFAR-10 first);
# a frozen teacher with an untrained head gives the student a meaningless signal.
# teacher.load_state_dict(torch.load('resnet50_cifar10.pt'))
for p in teacher.parameters():
    p.requires_grad = False  # freeze teacher
teacher.eval().to(DEVICE)

student = models.mobilenet_v2(weights=None)
student.classifier[1] = nn.Linear(student.last_channel, 10)
student.to(DEVICE)

# 3) KD loss
class DistillationLoss(nn.Module):
    def __init__(self, temperature=4.0, alpha=0.5):
        super().__init__()
        self.T = temperature
        self.alpha = alpha
        self.kl = nn.KLDivLoss(reduction='batchmean')
        self.ce = nn.CrossEntropyLoss()
    def forward(self, student_logits, teacher_logits, targets):
        T = self.T
        # Hard label loss
        ce_loss = self.ce(student_logits, targets)
        # Soft label loss: KL(teacher || student)
        s_log_probs = F.log_softmax(student_logits / T, dim=1)
        t_probs     = F.softmax(teacher_logits / T, dim=1)
        kd_loss = self.kl(s_log_probs, t_probs) * (T * T)
        return (1 - self.alpha) * ce_loss + self.alpha * kd_loss, ce_loss.detach(), kd_loss.detach()

criterion = DistillationLoss(temperature=4.0, alpha=0.5)
optimizer = torch.optim.AdamW(student.parameters(), lr=3e-4, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=30)  # T_max should match the number of training epochs
scaler = torch.cuda.amp.GradScaler(enabled=(DEVICE=='cuda'))

# 4) Training and evaluation loops
def evaluate(model, loader):
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(DEVICE), y.to(DEVICE)
            logits = model(x)
            pred = logits.argmax(dim=1)
            correct += (pred == y).sum().item()
            total += y.size(0)
    return correct / total

def train_kd(epochs=20):
    for epoch in range(1, epochs+1):
        student.train()
        total_loss, total_ce, total_kd, n = 0.0, 0.0, 0.0, 0
        for x, y in train_loader:
            x, y = x.to(DEVICE), y.to(DEVICE)
            with torch.no_grad():
                t_logits = teacher(x)
            optimizer.zero_grad(set_to_none=True)
            with torch.cuda.amp.autocast(enabled=(DEVICE=='cuda')):
                s_logits = student(x)
                loss, ce_l, kd_l = criterion(s_logits, t_logits, y)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            total_loss += loss.item() * y.size(0)
            total_ce   += ce_l.item() * y.size(0)
            total_kd   += kd_l.item() * y.size(0)
            n += y.size(0)
        scheduler.step()
        val_acc = evaluate(student, val_loader)
        print(f"Epoch {epoch:03d} | loss={total_loss/n:.4f} ce={total_ce/n:.4f} kd={total_kd/n:.4f} | val_acc={val_acc:.4f}")
    return student

student = train_kd(epochs=30)

# 5) Model size and latency profiling
param_count = sum(p.numel() for p in student.parameters())
print(f"Student parameters: {param_count/1e6:.2f}M")

student.eval()
example = torch.randn(1,3,32,32).to(DEVICE)
with torch.no_grad():
    # warmup
    for _ in range(20): student(example)
    if DEVICE=='cuda': torch.cuda.synchronize()  # let warmup kernels finish before timing
    import time
    iters = 100
    start = time.time()
    for _ in range(iters):
        _ = student(example)
        if DEVICE=='cuda': torch.cuda.synchronize()
    elapsed = time.time() - start
    print(f"Avg latency: {elapsed/iters*1000:.2f} ms")

Notes

  • For your dataset, fine‑tune the teacher on the same data before distilling.
  • Use mixed precision (already on) to speed up training on GPUs.
  • If memory allows, increase batch size for a more stable KD signal.

Feature‑based and attention transfer (optional)

You can intensify KD by matching internal representations. Example: L2 loss between a teacher’s intermediate feature map and the student’s corresponding feature (via a 1×1 conv to align channels).

# Example: add a feature hint loss from teacher.layer3 (1024 ch) to student.features[13] (96 ch);
# on 32x32 inputs both feature maps are 2x2, so their spatial sizes already match.
from contextlib import contextmanager

@contextmanager
def feature_hook(module):
    feat = {}
    def hook(_, __, output):
        feat['x'] = output
    h = module.register_forward_hook(hook)
    try:
        yield feat
    finally:
        h.remove()

adapter = nn.Conv2d(in_channels=96, out_channels=1024, kernel_size=1).to(DEVICE)  # student 96 ch -> teacher 1024 ch; add adapter.parameters() to the optimizer
feat_loss_fn = nn.MSELoss()

# Inside your training step, around forward passes:
with feature_hook(teacher.layer3) as Tfeat, feature_hook(student.features[13]) as Sfeat:
    t_logits = teacher(x)
    s_logits = student(x)
    # After forward, use Tfeat['x'] and Sfeat['x']
    f_loss = feat_loss_fn(adapter(Sfeat['x']), Tfeat['x'].detach())
    loss, ce_l, kd_l = criterion(s_logits, t_logits, y)
    loss = loss + 0.05 * f_loss  # small weight for stability

Other popular variants

  • Attention transfer: match spatial attention maps (mean of squared activations across channels); a minimal sketch follows this list.
  • Relational KD (RKD): match pairwise or triplet distances between samples in feature space.
  • Contrastive representation distillation (CRD): contrastive loss that pulls student features towards teacher positives and away from negatives.
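
As a rough sketch of attention transfer (a minimal version with my own simplifications, not a specific paper's implementation): reduce each feature map to a normalized spatial attention map by averaging squared activations over channels, then penalize the L2 distance between teacher and student maps.

import torch
import torch.nn.functional as F

def attention_map(feat):
    # feat: (N, C, H, W) -> (N, H*W); mean of squared activations over channels, L2-normalized
    am = feat.pow(2).mean(dim=1).flatten(1)
    return F.normalize(am, p=2, dim=1)

def attention_transfer_loss(student_feat, teacher_feat):
    # Assumes matching spatial sizes; interpolate one of the maps first if they differ
    return (attention_map(student_feat) - attention_map(teacher_feat.detach())).pow(2).mean()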

Hyperparameters that matter

  • Temperature T: higher T reveals more nuanced probabilities but can blur signal if too high. Try 2, 4, 6, 8.
  • α (distillation weight): target 0.3–0.7. If labels are noisy, increase α.
  • Learning rate and weight decay: KD often benefits from slightly lower LR vs. training from scratch.
  • Scheduling: consider cosine LR and optionally warm‑up. You may also anneal α from high to low across epochs, as sketched below.
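
A minimal sketch of such an annealing schedule (the linear decay and the 0.7 to 0.3 end points are arbitrary choices, not prescribed values):

def annealed_alpha(epoch, total_epochs, start=0.7, end=0.3):
    # Linearly decay the distillation weight from `start` to `end` over training
    frac = epoch / max(total_epochs - 1, 1)
    return start + (end - start) * frac

# Hypothetical usage inside the PyTorch training loop above:
# criterion.alpha = annealed_alpha(epoch - 1, epochs)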

Data strategies that boost KD

  • Use strong but label‑preserving augmentations: RandomCrop/Flip, RandAugment, MixUp/CutMix (tune the mixing parameters carefully alongside KD; start with MixUp α=0.2, which is the Beta parameter, not the distillation α).
  • Semi‑supervised KD: use the teacher to pseudo‑label unlabeled data; filter by confidence or keep everything with soft targets (see the sketch after this list).
  • Curriculum: start with higher T and α, gradually reduce as student stabilizes.
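
A minimal sketch of semi‑supervised KD with soft targets only, assuming an unlabeled_loader exists and reusing teacher, student, optimizer, and DEVICE from the tutorial above (the hard‑label term is dropped because there are no labels):

import torch
import torch.nn as nn
import torch.nn.functional as F

def distill_on_unlabeled(unlabeled_loader, T=4.0):
    kl = nn.KLDivLoss(reduction='batchmean')
    student.train()
    for x in unlabeled_loader:                               # batches without labels
        x = x.to(DEVICE)
        with torch.no_grad():
            t_probs = F.softmax(teacher(x) / T, dim=1)       # teacher soft targets
        s_log_probs = F.log_softmax(student(x) / T, dim=1)
        loss = kl(s_log_probs, t_probs) * (T * T)            # soft-target-only KD loss
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()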

Quantization and pruning with KD

KD pairs well with compression:

  • Post‑Training Quantization (PTQ): distill first, then quantize weights/activations to INT8; run a small calibration set to collect activation ranges.
  • Quantization‑Aware Training (QAT): insert fake‑quant nodes and continue training with KD so the teacher guides the int8‑simulated student.
  • Pruning: prune the student (structured channel pruning) and fine‑tune with KD to recover accuracy.

A full QAT walkthrough (FX/prepare_qat) is beyond this tutorial, but the recipe is: initialize a quantizable student -> prepare_qat -> train with KD -> convert -> evaluate. A rough sketch follows.
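
Here is a rough sketch of that recipe using the FX workflow, assuming the student traces cleanly with FX and an x86 'fbgemm' backend (API details vary across PyTorch releases):

import copy
import torch
from torch.ao.quantization import get_default_qat_qconfig_mapping
from torch.ao.quantization.quantize_fx import prepare_qat_fx, convert_fx

qconfig_mapping = get_default_qat_qconfig_mapping('fbgemm')
example_inputs = (torch.randn(1, 3, 32, 32),)

qat_student = copy.deepcopy(student).cpu().train()
qat_student = prepare_qat_fx(qat_student, qconfig_mapping, example_inputs)  # insert fake-quant nodes
qat_student.to(DEVICE)

# ... rerun the KD training loop above with qat_student in place of student ...

qat_student.eval().cpu()
int8_student = convert_fx(qat_student)  # materialize the int8 model for CPU inference
# evaluate(int8_student, val_loader)    # compare against the fp32 student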

Evaluation beyond accuracy

For edge deployments, track:

  • Latency P50/P95 on target hardware (CPU, mobile GPU, NPU).
  • Memory footprint (parameters × dtype; activations peak); a quick parameter-size estimate is sketched after this list.
  • Energy per inference (if available) and throughput.
  • Calibration drift after quantization; validate on real inputs.
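
For the parameter part of the footprint, a quick back-of-the-envelope estimate is sketched below (activation peaks still need a profiler or manual tracing); it reuses the PyTorch student from the tutorial.

def param_megabytes(model):
    # parameters x bytes-per-element; ignores activations (buffers are usually negligible)
    total_bytes = sum(p.numel() * p.element_size() for p in model.parameters())
    return total_bytes / (1024 ** 2)

print(f"Student weights: {param_megabytes(student):.2f} MB in fp32; roughly a quarter of that after int8 quantization")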

Minimal Keras/TensorFlow KD example

# pip install tensorflow
import tensorflow as tf
from tensorflow import keras

T = 4.0
alpha = 0.5

class Distiller(keras.Model):
    def __init__(self, student, teacher):
        super().__init__()
        self.student = student
        self.teacher = teacher
        self.ce = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
        self.kld = keras.losses.KLDivergence()

    def train_step(self, data):
        x, y = data
        teacher_logits = self.teacher(x, training=False)
        with tf.GradientTape() as tape:
            student_logits = self.student(x, training=True)
            ce_loss = self.ce(y, student_logits)
            t_probs = tf.nn.softmax(teacher_logits / T)
            s_log_probs = tf.nn.log_softmax(student_logits / T)
            kd_loss = self.kld(t_probs, tf.exp(s_log_probs)) * (T*T)
            loss = (1 - alpha) * ce_loss + alpha * kd_loss
        grads = tape.gradient(loss, self.student.trainable_variables)
        self.optimizer.apply_gradients(zip(grads, self.student.trainable_variables))
        return {"loss": loss, "ce": ce_loss, "kd": kd_loss}

# Example: CIFAR-10 student/teacher
# classifier_activation=None keeps the outputs as logits, matching from_logits=True below
teacher = keras.applications.ResNet50(weights=None, classes=10, input_shape=(32,32,3), classifier_activation=None)
student = keras.applications.MobileNetV2(weights=None, classes=10, input_shape=(32,32,3), classifier_activation=None)

teacher.compile(optimizer='adam', loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True))
# teacher.fit(...)  # train or load weights

distiller = Distiller(student, teacher)
distiller.compile(optimizer=keras.optimizers.Adam(3e-4))
# distiller.fit(train_ds, validation_data=val_ds, epochs=30)

Troubleshooting checklist

  • Student plateaus far below teacher: increase α or T; ensure teacher is accurate on the same data; consider feature‑based KD.
  • Overfitting: add regularization (weight decay, dropout), stronger augmentation, or semi‑supervised KD with unlabeled data.
  • Unstable training or exploding loss: lower LR; clamp logits; reduce T.
  • Post‑quantization accuracy drop: perform QAT with KD; recalibrate with a more diverse set; check per‑channel quant.
  • Domain shift between train and deploy: augment with realistic noise and distributions; consider Test‑Time Augmentation during validation to sanity‑check.

Reproducibility and instrumentation

  • Seed everything (torch, numpy, dataloader workers) but expect some nondeterminism on GPUs; a small helper is sketched after this list.
  • Log α, T, LR, and validation curves; save best checkpoints.
  • Track teacher–student KL across epochs; it should generally decrease.
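
A minimal seeding helper is sketched below; it covers the common RNGs, while full GPU determinism usually also needs torch.use_deterministic_algorithms(True) at a speed cost.

import random
import numpy as np
import torch

def seed_everything(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)            # seeds CPU and current-device CUDA RNGs
    torch.cuda.manual_seed_all(seed)   # seeds all GPUs

def seed_worker(worker_id):
    # Pass as worker_init_fn=seed_worker to DataLoader for reproducible workers
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)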

Deployment tips

  • Export: TorchScript or ONNX from PyTorch; TFLite from Keras/TF (see the sketch after this list).
  • Target optimizations: operator fusion, int8 kernels, Winograd/FFT for convs, or vendor‑provided NN libraries.
  • Validate real‑world latency on device; desktop timing is a poor proxy for mobile.
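
A minimal export sketch for the PyTorch student (file names and the opset are arbitrary; the Keras model would go through tf.lite.TFLiteConverter instead):

import torch

student.eval().cpu()
example = torch.randn(1, 3, 32, 32)

# TorchScript via tracing
traced = torch.jit.trace(student, example)
traced.save('student_traced.pt')

# ONNX export
torch.onnx.export(student, example, 'student.onnx',
                  input_names=['input'], output_names=['logits'], opset_version=17)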

A practical recipe to start

  1. Fine‑tune a strong teacher on your data.
  2. Pick a mobile‑friendly student.
  3. Train with KD (T=4, α=0.5).
  4. Add light feature KD if needed.
  5. Quantize (PTQ) and measure.
  6. If accuracy drops, switch to QAT+KD.
  7. Profile on target hardware and iterate.

Conclusion

Knowledge distillation is a simple, powerful lever to ship compact models with competitive accuracy. With a good teacher, thoughtful hyperparameters, and careful evaluation of latency and memory, you can meet strict on‑device budgets without sacrificing user experience.
