Normalization

12 min · Intermediate · gen-ai

LayerNorm, BatchNorm, RMSNorm: what they do, when to use them, and Pre-Norm vs Post-Norm

Interview Relevance: 55% of architecture interviews
Production Impact: understanding transformer architecture
Performance: RMSNorm is ~7% faster than LayerNorm

TL;DR

Normalization keeps activations in a reasonable range during training. LayerNorm is standard for transformers, RMSNorm is faster and used in modern LLMs like Llama. Pre-norm (normalize before sublayer) is more stable than Post-norm for deep networks.

Visual Overview

THE PROBLEM: INTERNAL COVARIATE SHIFT
+-----------------------------------------------------------+
|                                                           |
|   During training, each layer's input distribution        |
|   changes as previous layers update.                      |
|                                                           |
|   Epoch 1: Layer 3 receives inputs with mean=0, std=1     |
|   Epoch 2: Layer 2 weights changed -> Layer 3 now sees    |
|            mean=0.5, std=2                                |
|   Epoch 3: More drift -> Layer 3 sees mean=1.2, std=3.5   |
|                                                           |
|   Layer 3 keeps having to re-adapt to shifting inputs.    |
|   Training is slower and less stable.                     |
|                                                           |
|   SOLUTION: Normalize activations to have consistent      |
|   statistics.                                             |
|                                                           |
+-----------------------------------------------------------+

Batch Normalization

Normalizes across the batch dimension. Each feature is normalized using statistics from the current batch.

BATCH NORMALIZATION
+-----------------------------------------------------------+
|                                                           |
|   Input: x with shape (batch_size, features)              |
|                                                           |
|   For each feature f:                                     |
|     mu_f = mean(x[:, f])        # mean across batch       |
|     sigma_f = std(x[:, f])      # std across batch        |
|                                                           |
|     x_norm[:, f] = (x[:, f] - mu_f) / (sigma_f + eps)     |
|                                                           |
|   Then apply learnable scale and shift:                   |
|     output = gamma * x_norm + beta                        |
|                                                           |
|   gamma and beta are learned per feature.                 |
|                                                           |
+-----------------------------------------------------------+

VISUAL: NORMALIZE ACROSS BATCH
+-----------------------------------------------------------+
|                                                           |
|         Feature 1   Feature 2   Feature 3                 |
|        +----------+----------+----------+                 |
| Batch 1|    2.1   |    0.5   |   -1.2   |                 |
|        +----------+----------+----------+                 |
| Batch 2|    1.8   |    0.7   |   -0.9   |  <- Normalize   |
|        +----------+----------+----------+     down each   |
| Batch 3|    2.3   |    0.4   |   -1.1   |     column      |
|        +----------+----------+----------+                 |
|             v          v          v                       |
|          mu=2.07    mu=0.53   mu=-1.07                    |
|                                                           |
+-----------------------------------------------------------+

When it works well:

  • CNNs (computer vision)
  • Large batch sizes (stable statistics)
  • Training (has batch to compute stats)

Problems:

  • Needs batch statistics at inference (use running average)
  • Small batches -> noisy statistics -> unstable
  • Batch size 1 -> undefined (no batch to normalize over)
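
A minimal sketch of the per-feature computation above, assuming PyTorch (names and shapes are illustrative). In training mode it matches torch.nn.BatchNorm1d, which additionally tracks running statistics for inference:

# Manual batch norm over (batch_size, features) -- sketch
import torch

x = torch.randn(32, 8)                       # (batch_size=32, features=8)
mu = x.mean(dim=0, keepdim=True)             # per-feature mean across the batch
var = x.var(dim=0, unbiased=False, keepdim=True)
x_norm = (x - mu) / torch.sqrt(var + 1e-5)   # PyTorch puts eps inside the sqrt

gamma = torch.ones(8)                        # learnable scale, initialized to 1
beta = torch.zeros(8)                        # learnable shift, initialized to 0
out = gamma * x_norm + beta

bn = torch.nn.BatchNorm1d(8)                 # training mode: uses batch statistics
assert torch.allclose(out, bn(x), atol=1e-5)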

Layer Normalization

Normalizes across the feature dimension. Each sample is normalized independently.

LAYER NORMALIZATION
+-----------------------------------------------------------+
|                                                           |
|   Input: x with shape (batch_size, features)              |
|                                                           |
|   For each sample i:                                      |
|     mu_i = mean(x[i, :])        # mean across features    |
|     sigma_i = std(x[i, :])      # std across features     |
|                                                           |
|     x_norm[i, :] = (x[i, :] - mu_i) / (sigma_i + eps)     |
|                                                           |
|   Then apply learnable scale and shift:                   |
|     output = gamma * x_norm + beta                        |
|                                                           |
+-----------------------------------------------------------+

VISUAL: NORMALIZE ACROSS FEATURES
+-----------------------------------------------------------+
|                                                           |
|         Feature 1   Feature 2   Feature 3                 |
|        +----------+----------+----------+                 |
| Batch 1|    2.1   |    0.5   |   -1.2   | -> Normalize    |
|        +----------+----------+----------+    this row     |
| Batch 2|    1.8   |    0.7   |   -0.9   | -> Normalize    |
|        +----------+----------+----------+    this row     |
| Batch 3|    2.3   |    0.4   |   -1.1   | -> Normalize    |
|        +----------+----------+----------+    this row     |
|                                                           |
+-----------------------------------------------------------+

When it works well:

  • Transformers (the standard)
  • RNNs, LSTMs
  • Any batch size (including 1)
  • Inference (no batch dependency)

Why transformers use LayerNorm:

  • Sequence length varies -> batch statistics meaningless
  • Inference often batch_size=1
  • Each token normalized independently
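
The same idea applied row-wise, as a minimal sketch checked against torch.nn.LayerNorm (names and shapes here are illustrative):

# Manual layer norm over the feature dimension -- sketch
import torch

x = torch.randn(4, 16)                       # (batch_size=4, features=16)
mu = x.mean(dim=-1, keepdim=True)            # per-sample mean across features
var = x.var(dim=-1, unbiased=False, keepdim=True)
x_norm = (x - mu) / torch.sqrt(var + 1e-5)   # each row normalized independently

gamma = torch.ones(16)                       # learnable scale
beta = torch.zeros(16)                       # learnable shift
out = gamma * x_norm + beta

ln = torch.nn.LayerNorm(16)                  # default eps=1e-5, elementwise affine
assert torch.allclose(out, ln(x), atol=1e-5)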

RMSNorm (Root Mean Square Normalization)

Simplified LayerNorm: it only rescales (dividing by the root mean square) and drops mean centering.

RMSNORM
+-----------------------------------------------------------+
|                                                           |
|   Standard LayerNorm:                                     |
|     x_norm = (x - mean(x)) / std(x)                       |
|                                                           |
|   RMSNorm:                                                |
|     x_norm = x / RMS(x)                                   |
|                                                           |
|     where RMS(x) = sqrt(mean(x^2))                        |
|                                                           |
|   No mean subtraction. Just scale by root-mean-square.    |
|                                                           |
+-----------------------------------------------------------+

Why it works:

  • Mean centering turns out to be less important than variance scaling
  • Removing mean computation saves ~7% training time
  • Quality is equivalent or better in practice

Used in: Llama, Llama 2, Mistral, and most other modern LLMs

# RMSNorm implementation
import torch

def rmsnorm(x, weight, eps=1e-6):
    # Root mean square over the feature (last) dimension; eps avoids division by zero
    rms = torch.sqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + eps)
    # No mean subtraction, no bias: rescale and apply the learned per-feature weight
    return weight * (x / rms)
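
A quick usage check for the function above (shapes are illustrative; recent PyTorch releases also provide an equivalent torch.nn.RMSNorm module):

# Usage sketch for rmsnorm
x = torch.randn(2, 5, 64)            # (batch, seq_len, hidden)
weight = torch.ones(64)              # learned per-feature scale, initialized to 1
y = rmsnorm(x, weight)

# Each position now has (approximately) unit root-mean-square over the hidden dim
print(torch.sqrt(torch.mean(y ** 2, dim=-1))[0, 0])  # ~1.0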

Pre-Norm vs Post-Norm

Where you place normalization matters for training stability.

POST-NORM (Original Transformer)
+-----------------------------------------------------------+
|                                                           |
|   x = x + Attention(x)                                    |
|   x = LayerNorm(x)         <- Norm AFTER residual         |
|   x = x + FFN(x)                                          |
|   x = LayerNorm(x)                                        |
|                                                           |
|   Problem: Gradients must flow through LayerNorm          |
|            Can cause instability in deep networks         |
|                                                           |
+-----------------------------------------------------------+
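
A minimal post-norm block sketch in PyTorch (module and argument names are illustrative, not taken from a specific codebase):

# Post-norm transformer block (sketch)
import torch.nn as nn

class PostNormBlock(nn.Module):
    def __init__(self, d_model, n_heads, d_ff):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_ff), nn.GELU(), nn.Linear(d_ff, d_model))
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x):
        x = self.norm1(x + self.attn(x, x, x)[0])  # norm AFTER the residual add
        x = self.norm2(x + self.ffn(x))
        return x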

PRE-NORM (Modern Transformers)
+-----------------------------------------------------------+
|                                                           |
|   x = x + Attention(LayerNorm(x))    <- Norm BEFORE       |
|   x = x + FFN(LayerNorm(x))                               |
|                                                           |
|   Advantages:                                             |
|     - Residual stream is "clean" (just additions)         |
|     - Gradients flow directly through residual path       |
|     - More stable for deep networks                       |
|     - Easier to train without careful LR tuning           |
|                                                           |
+-----------------------------------------------------------+
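
The pre-norm equivalent as a sketch; only the placement of the norms changes (same illustrative names as the post-norm sketch above):

# Pre-norm transformer block (sketch)
import torch.nn as nn

class PreNormBlock(nn.Module):
    def __init__(self, d_model, n_heads, d_ff):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_ff), nn.GELU(), nn.Linear(d_ff, d_model))
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x):
        h = self.norm1(x)
        x = x + self.attn(h, h, h)[0]    # norm BEFORE attention; residual path untouched
        x = x + self.ffn(self.norm2(x))  # norm BEFORE the FFN
        return x

# e.g. PreNormBlock(d_model=512, n_heads=8, d_ff=2048) on input of shape (batch, seq, 512)

Note that pre-norm stacks typically also apply one final norm after the last block, before the output head.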

Which to use: Pre-norm for new models. Post-norm only if replicating the original Transformer or BERT (GPT-2 already uses pre-norm).


Comparison Table

Aspect               | BatchNorm | LayerNorm    | RMSNorm
---------------------+-----------+--------------+------------
Normalizes across    | Batch     | Features     | Features
Works with batch=1   | No        | Yes          | Yes
Needs running stats  | Yes       | No           | No
Mean centering       | Yes       | Yes          | No
Speed                | Baseline  | Baseline     | ~7% faster
Used in              | CNNs      | Transformers | Modern LLMs

Debugging Normalization Issues

TRAINING INSTABILITY (LOSS SPIKES)
+-----------------------------------------------------------+
|                                                           |
|   Symptoms:                                               |
|     - Loss suddenly spikes during training                |
|     - Gradients explode intermittently                    |
|                                                           |
|   Causes:                                                 |
|     - Post-norm architecture with deep network            |
|     - Missing normalization somewhere                     |
|     - Norm placed incorrectly                             |
|                                                           |
|   Debug steps:                                            |
|     1. Switch to pre-norm if using post-norm              |
|     2. Check every sublayer has normalization (see below) |
|     3. Verify norm is before attention/FFN, not after     |
|     4. Reduce learning rate                               |
|                                                           |
+-----------------------------------------------------------+
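
For debug step 2, a quick sketch that lists the normalization layers present in a PyTorch model (extend the isinstance check if your model uses a custom RMSNorm class):

# List normalization layers in a model (sketch)
import torch.nn as nn

def list_norm_layers(model: nn.Module):
    norms = [(name, type(m).__name__) for name, m in model.named_modules()
             if isinstance(m, (nn.LayerNorm, nn.BatchNorm1d, nn.BatchNorm2d))]
    for name, kind in norms:
        print(f"{name}: {kind}")
    return norms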

ACTIVATIONS GROWING UNBOUNDED
+-----------------------------------------------------------+
|                                                           |
|   Symptoms:                                               |
|     - Activation magnitudes grow over layers              |
|     - Eventually overflow to NaN                          |
|                                                           |
|   Causes:                                                 |
|     - Missing normalization layer                         |
|     - Residual accumulation without norm                  |
|     - Wrong norm dimension                                |
|                                                           |
|   Debug steps:                                            |
|     1. Print activation statistics per layer (see below)  |
|     2. Verify norm is applied (gamma, beta params exist)  |
|     3. Check norm dimension matches input shape           |
|                                                           |
+-----------------------------------------------------------+
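
For debug step 1, a minimal sketch that prints per-layer activation statistics using forward hooks (assumes any PyTorch nn.Module; names are illustrative):

# Print per-layer activation statistics with forward hooks (sketch)
import torch

def log_activation_stats(model):
    def make_hook(name):
        def hook(module, inputs, output):
            if isinstance(output, torch.Tensor):
                print(f"{name}: mean={output.mean().item():.3f} "
                      f"std={output.std().item():.3f} "
                      f"max={output.abs().max().item():.3f}")
        return hook
    # Skip the root module (its name is the empty string)
    return [m.register_forward_hook(make_hook(n))
            for n, m in model.named_modules() if n]

# Typical usage: handles = log_activation_stats(model); run one forward pass,
# inspect the printed stats, then call h.remove() on each handle.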

When This Matters

Situation                    | What to know
-----------------------------+--------------------------------------------
Reading transformer code     | LayerNorm before attention/FFN (pre-norm)
Understanding Llama/Mistral  | RMSNorm, not LayerNorm
Training instability         | Switch to pre-norm, check norm placement
Batch size constraints       | LayerNorm works with any batch size
Optimizing inference speed   | RMSNorm is slightly faster
Porting CNN techniques       | BatchNorm doesn't work for transformers
Understanding model configs  | "norm_eps" is the epsilon in the denominator