Batch Normalization

Normalizing layer inputs to accelerate deep network training

Batch Normalization (BatchNorm), introduced by Ioffe and Szegedy in 2015, is one of the most important techniques in deep learning. It normalizes layer inputs during training, enabling faster convergence, higher learning rates, and more stable training of deep networks.

The Problem: Internal Covariate Shift

As training progresses, the distribution of each layer’s inputs changes because the preceding layers’ weights change. This is called internal covariate shift.

Networks must constantly adapt to new input distributions, slowing training.

The Solution: Normalize Each Mini-Batch

For each feature in a layer, normalize across the mini-batch:

\hat{x}_i = \frac{x_i - \mu_B}{\sqrt{\sigma_B^2 + \epsilon}}

where:

  • \mu_B = \frac{1}{m}\sum_{i=1}^m x_i (batch mean)
  • \sigma_B^2 = \frac{1}{m}\sum_{i=1}^m (x_i - \mu_B)^2 (batch variance)
  • \epsilon is a small constant for numerical stability
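As a concrete illustration of the formula above (a sketch with arbitrary example values, not from the original), here is the per-feature normalization computed directly on a tiny mini-batch:

```python
import torch

# A tiny mini-batch: 4 samples, 2 features (values chosen for illustration)
x = torch.tensor([[1.0, 10.0],
                  [2.0, 20.0],
                  [3.0, 30.0],
                  [4.0, 40.0]])

mu = x.mean(dim=0)                       # mu_B: per-feature batch mean
var = x.var(dim=0, unbiased=False)       # sigma_B^2: biased batch variance
x_hat = (x - mu) / torch.sqrt(var + 1e-5)

# Each feature column now has mean ~0 and variance ~1
print(x_hat.mean(dim=0))
print(x_hat.var(dim=0, unbiased=False))
```

Note the biased variance (`unbiased=False`): the formula divides by m, not m - 1.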

Learnable Parameters

Simply normalizing would limit what the layer can represent. BatchNorm therefore adds a learnable scale \gamma and shift \beta:

y_i = \gamma \hat{x}_i + \beta

This allows the network to learn to undo normalization if needed.
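The "undo" claim can be verified directly: if the network learns \gamma = \sqrt{\sigma_B^2 + \epsilon} and \beta = \mu_B, the output reproduces the unnormalized activations. A minimal sketch (not from the original):

```python
import torch

x = torch.randn(8, 3) * 5.0 + 2.0        # batch with non-zero mean, non-unit std
mu = x.mean(dim=0)
var = x.var(dim=0, unbiased=False)
x_hat = (x - mu) / torch.sqrt(var + 1e-5)

# Setting gamma = sqrt(var + eps) and beta = mu exactly inverts
# the normalization, recovering the original activations.
gamma = torch.sqrt(var + 1e-5)
beta = mu
y = gamma * x_hat + beta

print(torch.allclose(y, x, atol=1e-4))   # True
```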

Interactive Demo

Visualize how BatchNorm stabilizes activations during training:

[Interactive demo: activation distributions over 20 training epochs, with live readouts of the batch mean, standard deviation, and variance, alongside the BatchNorm transform ŷ = γ · (x − μ) / σ + β, where γ (scale) and β (shift) are learnable parameters.]

With BatchNorm, activations stay centered (μ ≈ 0) with unit variance (σ ≈ 1) throughout training, which stabilizes gradients and allows higher learning rates.

Where to Apply BatchNorm

Without BatchNorm:    Input → Linear → Activation → Linear → ...
With BatchNorm:       Input → Linear → BatchNorm → Activation → Linear → ...

Typically applied after the linear transformation but before the activation function.
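The placement above can be sketched as a small PyTorch model (an illustrative example, not from the original; `bias=False` on the first Linear is a common choice since BatchNorm's \beta makes the Linear bias redundant):

```python
import torch
import torch.nn as nn

# Minimal MLP block following the Linear -> BatchNorm -> Activation pattern
model = nn.Sequential(
    nn.Linear(64, 128, bias=False),   # bias is redundant before BatchNorm
    nn.BatchNorm1d(128),
    nn.ReLU(),
    nn.Linear(128, 10),
)

out = model(torch.randn(32, 64))      # training mode: normalizes over the batch
print(out.shape)                      # torch.Size([32, 10])
```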

Training vs. Inference

During training: Use mini-batch statistics (\mu_B, \sigma_B^2)

During inference: Use running averages computed during training:

\mu_{running} = \alpha \cdot \mu_{running} + (1 - \alpha) \cdot \mu_B

An analogous running average is kept for \sigma_B^2.

This ensures deterministic outputs at inference time.
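PyTorch's BatchNorm layers implement exactly this switch via train/eval modes. A small sketch (note that PyTorch's `momentum` corresponds to 1 - \alpha in the update rule above):

```python
import torch
import torch.nn as nn

bn = nn.BatchNorm1d(4, momentum=0.1)   # momentum = 1 - alpha

x = torch.randn(16, 4) * 3.0 + 5.0

bn.train()
_ = bn(x)    # training mode: normalizes with this batch's statistics
             # and updates running_mean / running_var

bn.eval()
y = bn(x)    # eval mode: uses the stored running statistics instead

# After one update from zero-initialized running_mean:
# running_mean = 0.9 * 0 + 0.1 * batch_mean
print(bn.running_mean)
```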

Why It Works

The original paper attributed success to reducing internal covariate shift, but later research suggests other factors:

  1. Smoother loss landscape: BatchNorm makes the optimization surface smoother, allowing larger learning rates
  2. Gradient flow: Normalization prevents gradients from vanishing or exploding
  3. Regularization: Mini-batch noise acts as a regularizer

Benefits

Benefit                               Explanation
Faster training                       10-14x fewer training steps
Higher learning rates                 Stable training with larger steps
Reduced initialization sensitivity    Less dependent on weight initialization
Regularization effect                 Reduces need for dropout

BatchNorm Variants

Variant        Normalization Dimension     Use Case
BatchNorm      Across batch                CNNs (requires large batches)
LayerNorm      Across features             Transformers, RNNs
InstanceNorm   Per sample, per channel     Style transfer
GroupNorm      Groups of channels          Small batch sizes
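The difference between the first two rows is just which axis the statistics are computed over. A sketch contrasting them on a (batch, features) tensor:

```python
import torch
import torch.nn as nn

x = torch.randn(8, 16)        # (batch, features)

bn = nn.BatchNorm1d(16)       # normalizes each feature across the batch
ln = nn.LayerNorm(16)         # normalizes each sample across its features

yb = bn(x)
yl = ln(x)

# BatchNorm: per-feature statistics over the batch dimension (dim=0)
print(yb.mean(dim=0))         # ~0 for every feature
# LayerNorm: per-sample statistics over the feature dimension (dim=1)
print(yl.mean(dim=1))         # ~0 for every sample
```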

The Forward Pass

import torch

def batch_norm(x, gamma, beta, eps=1e-5):
    # x shape: (batch_size, features)
    mu = x.mean(dim=0)
    # Biased variance (divide by m, not m - 1), matching sigma_B^2 above
    var = x.var(dim=0, unbiased=False)

    x_norm = (x - mu) / torch.sqrt(var + eps)

    # Learnable scale and shift
    return gamma * x_norm + beta
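As a sanity check (an illustrative sketch, not from the original), the forward pass should agree with PyTorch's built-in nn.BatchNorm1d in training mode when \gamma = 1 and \beta = 0; the key detail is using the biased variance:

```python
import torch
import torch.nn as nn

x = torch.randn(32, 8)

# Manual forward pass with default gamma = 1, beta = 0; note the *biased*
# variance (unbiased=False), which is what BatchNorm normalizes with.
mu = x.mean(dim=0)
var = x.var(dim=0, unbiased=False)
ours = (x - mu) / torch.sqrt(var + 1e-5)

ref = nn.BatchNorm1d(8, eps=1e-5).train()(x)
print(torch.allclose(ours, ref, atol=1e-5))   # True
```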

Limitations

  1. Batch size dependency: Small batches give noisy statistics
  2. Not suited for RNNs: Variable sequence lengths complicate batch statistics
  3. Train/test discrepancy: Different behavior in train vs. inference modes

For transformers and RNNs, Layer Normalization is preferred.

The Math: Backward Pass

Gradients flow through normalization:

\frac{\partial L}{\partial \gamma} = \sum_i \frac{\partial L}{\partial y_i} \cdot \hat{x}_i

\frac{\partial L}{\partial \beta} = \sum_i \frac{\partial L}{\partial y_i}

The gradient through \hat{x}_i involves the chain rule through the batch mean and variance, since both depend on every x_i in the mini-batch.
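The \gamma and \beta gradient formulas can be checked against autograd (a sketch using an arbitrary scalar loss, not from the original):

```python
import torch

x = torch.randn(16, 4)
gamma = torch.ones(4, requires_grad=True)
beta = torch.zeros(4, requires_grad=True)

x_hat = (x - x.mean(dim=0)) / torch.sqrt(x.var(dim=0, unbiased=False) + 1e-5)
y = gamma * x_hat + beta
loss = (y ** 2).sum()          # arbitrary scalar loss; dL/dy_i = 2 * y_i
loss.backward()

dL_dy = 2 * y.detach()
grad_gamma = (dL_dy * x_hat).sum(dim=0)   # sum_i dL/dy_i * x_hat_i
grad_beta = dL_dy.sum(dim=0)              # sum_i dL/dy_i

# These should match gamma.grad and beta.grad computed by autograd
print(torch.allclose(gamma.grad, grad_gamma, atol=1e-4))
print(torch.allclose(beta.grad, grad_beta, atol=1e-4))
```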

Historical Impact

BatchNorm enabled:

  • Training of much deeper networks (ResNet’s 152 layers)
  • Higher learning rates (faster experimentation)
  • Reduced hyperparameter sensitivity
  • Standard component in nearly all modern architectures

Key Papers

  • Ioffe & Szegedy (2015), "Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift"
  • Santurkar et al. (2018), "How Does Batch Normalization Help Optimization?"
