Layer Normalization

Normalizing across features for sequence models and Transformers

Layer Normalization (LayerNorm) normalizes across the feature dimension rather than the batch dimension. This makes it ideal for Transformers and RNNs where batch statistics are problematic.

The Key Difference

BatchNorm: Normalize across batch for each feature

$$\hat{x}_{n,d} = \frac{x_{n,d} - \mu_d}{\sqrt{\sigma_d^2 + \epsilon}}, \quad \mu_d = \frac{1}{N}\sum_n x_{n,d}$$

LayerNorm: Normalize across features for each sample

$$\hat{x}_{n,d} = \frac{x_{n,d} - \mu_n}{\sqrt{\sigma_n^2 + \epsilon}}, \quad \mu_n = \frac{1}{D}\sum_d x_{n,d}$$
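The two formulas differ only in the axis of reduction. A minimal NumPy sketch (the 4×6 array is a toy example, and the variable names are illustrative):

```python
import numpy as np

x = np.random.randn(4, 6)  # (N samples, D features)
eps = 1e-5

# BatchNorm: statistics per feature, computed across the batch (axis 0)
bn = (x - x.mean(axis=0)) / np.sqrt(x.var(axis=0) + eps)

# LayerNorm: statistics per sample, computed across the features (axis 1)
ln = (x - x.mean(axis=1, keepdims=True)) / np.sqrt(x.var(axis=1, keepdims=True) + eps)

print(bn.mean(axis=0))  # ~0 for each feature (column)
print(ln.mean(axis=1))  # ~0 for each sample (row)
```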

The Algorithm

For an input $x \in \mathbb{R}^D$:

$$\mu = \frac{1}{D}\sum_{d=1}^{D} x_d, \quad \sigma^2 = \frac{1}{D}\sum_{d=1}^{D} (x_d - \mu)^2$$

$$\hat{x} = \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}}, \quad y = \gamma \odot \hat{x} + \beta$$

where $\gamma, \beta \in \mathbb{R}^D$ are learnable parameters.
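The algorithm is a few lines of NumPy. This is a sketch of the math above, not the fused kernels real frameworks use:

```python
import numpy as np

def layer_norm(x, gamma, beta, eps=1e-5):
    """LayerNorm over the last (feature) dimension of x."""
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    x_hat = (x - mu) / np.sqrt(var + eps)  # normalize per sample
    return gamma * x_hat + beta            # learnable scale and shift

D = 6
x = np.random.randn(4, D)
# With gamma=1, beta=0 the output has mean ~0 and variance ~1 per row
y = layer_norm(x, gamma=np.ones(D), beta=np.zeros(D))
```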

Interactive Visualization

Compare BatchNorm and LayerNorm normalization patterns:

[Interactive figure: "Normalization Comparison" widget showing a 4×6 grid of original values ("Original", 6 features per sample) alongside its LayerNorm output ("Normalized").]
LayerNorm: Each row (sample) is normalized independently. μ≈0, σ²≈1 per row.

Why LayerNorm for Transformers?

  1. Batch independence: Each sequence is normalized independently
  2. Variable sequence lengths: No batch statistics needed
  3. Inference consistency: Same computation at train and test time
  4. Autoregressive generation: Works with batch size 1
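Point 1 is easy to verify numerically: because LayerNorm never looks across the batch, a sample's output is identical whether it is normalized alone or inside a larger batch. A small NumPy check (unscaled LayerNorm, with γ and β omitted for brevity):

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)

batch = np.random.randn(8, 16)   # batch of 8 samples
single = batch[:1]               # the first sample alone, batch size 1

# The first row's output does not depend on the other 7 samples
out_in_batch = layer_norm(batch)[0]
out_alone = layer_norm(single)[0]
assert np.allclose(out_in_batch, out_alone)
```

The same check would fail for BatchNorm, whose statistics change with the rest of the batch.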

Pre-Norm vs Post-Norm

Post-Norm (original Transformer):

$$x' = \text{LayerNorm}(x + \text{Sublayer}(x))$$

Pre-Norm (modern preference):

$$x' = x + \text{Sublayer}(\text{LayerNorm}(x))$$

Pre-Norm leaves the residual path as an identity shortcut, which improves gradient flow through deep stacks and often converges faster.
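The contrast can be sketched with stand-in callables. Here `norm` and `sublayer` are toy placeholders (a parameter-free LayerNorm and a simple scaling), not real attention:

```python
import numpy as np

def norm(x, eps=1e-5):  # stand-in LayerNorm, no learnable params
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)

def sublayer(x):  # stand-in for attention or FFN
    return 2.0 * x

def post_norm_block(x):
    return norm(x + sublayer(x))   # normalize after the residual add

def pre_norm_block(x):
    return x + sublayer(norm(x))   # residual path passes through untouched

x = np.random.randn(2, 8)
```

Note the structural difference: in the pre-norm block the input `x` reaches the output unmodified through the residual add, whereas the post-norm block rescales the entire residual stream at every layer.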

RMSNorm: A Simpler Alternative

Remove the mean centering:

$$\hat{x} = \frac{x}{\sqrt{\frac{1}{D}\sum_d x_d^2 + \epsilon}} \cdot \gamma$$

Used in LLaMA and many modern LLMs—simpler and often works just as well.
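A minimal NumPy sketch of the formula above (γ is kept as a per-feature scale, here just ones; this is an illustration, not LLaMA's implementation):

```python
import numpy as np

def rms_norm(x, gamma, eps=1e-5):
    # No mean subtraction: scale by the root-mean-square of the features
    rms = np.sqrt((x ** 2).mean(axis=-1, keepdims=True) + eps)
    return gamma * (x / rms)

D = 6
x = np.random.randn(4, D)
y = rms_norm(x, gamma=np.ones(D))  # each row now has mean-square ~1
```

Dropping the mean saves one reduction per call and removes the β parameter, which is part of why it is popular at LLM scale.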

Comparison Table

| Aspect | BatchNorm | LayerNorm | RMSNorm |
|---|---|---|---|
| Normalizes over | Batch | Features | Features |
| Learnable params | 2D | 2D | D |
| Mean centering | Yes | Yes | No |
| Best for | CNNs | Transformers | LLMs |

Common Placements in Transformers

# Pre-norm Transformer block
def forward(self, x):
    x = x + self.attn(self.norm1(x))
    x = x + self.ffn(self.norm2(x))
    return x
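The fragment above assumes a surrounding module. A self-contained toy version in NumPy, with the attention and feed-forward sublayers reduced to single linear maps purely for illustration:

```python
import numpy as np

class PreNormBlock:
    """Illustrative pre-norm block; attn/ffn are stand-in linear maps."""
    def __init__(self, d, seed=0):
        rng = np.random.default_rng(seed)
        self.w_attn = rng.standard_normal((d, d)) / np.sqrt(d)
        self.w_ffn = rng.standard_normal((d, d)) / np.sqrt(d)

    def norm(self, x, eps=1e-5):
        mu = x.mean(axis=-1, keepdims=True)
        var = x.var(axis=-1, keepdims=True)
        return (x - mu) / np.sqrt(var + eps)

    def forward(self, x):
        x = x + self.norm(x) @ self.w_attn  # "attention" sublayer
        x = x + self.norm(x) @ self.w_ffn   # "feed-forward" sublayer
        return x

block = PreNormBlock(d=16)
out = block.forward(np.random.randn(4, 16))  # shape preserved: (4, 16)
```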

Key Insight

LayerNorm’s batch independence makes it essential for:

  • Autoregressive generation (batch size 1)
  • Variable-length sequences
  • Distributed training across sequence dimension