Exponential Moving Average (EMA) in PyTorch
https://github.com/lucidrains/ema-pytorch
TL;DR
EMA (Exponential Moving Average) maintains a smoothed version of the model weights:

$$\theta_{\text{EMA}} \leftarrow \alpha \cdot \theta_{\text{EMA}} + (1 - \alpha) \cdot \theta_{\text{current}}$$
It stabilizes training and usually gives better validation/inference results, especially for generative models (diffusion, GANs) and large models.
The hard rule
- EMA parameters must never receive gradients and must never be part of the optimizer.
- EMA is a non-trainable, read-only copy of the weights. But: it is a dynamically updated copy, not a static snapshot.
```
┌────────────┐
│    loss    │
└─────┬──────┘
      ↓
┌────────────┐
│  raw model │  ← gradients flow ONLY here
└─────┬──────┘
      ↓
optimizer.step()
      ↓  (after step, no grad)
┌────────────┐
│  EMA model │  ← read-only
└────────────┘
```
1) What EMA Is (Conceptually)
EMA maintains a smoothed copy of model parameters over training time.
Mathematically:

$$\theta_{\text{EMA}} \leftarrow \alpha \cdot \theta_{\text{EMA}} + (1 - \alpha) \cdot \theta_{\text{current}}$$

Where:
- $$\theta_{\text{EMA}}$$: the smoothed (EMA) parameters
- $$\theta_{\text{current}}$$: current model parameters
- $$\alpha$$: decay factor (typically close to 1, e.g. 0.999 or 0.9999)
2) Intuition: Why EMA Helps
Stochastic optimizers (SGD, Adam, AdamW) produce noisy parameter updates due to:
- minibatch randomness
- non-convex loss landscapes
- oscillations around minima
- aggressive learning rates
Parameter evolution often looks like:
bounce → overshoot → oscillate → drift → settle → oscillate
EMA acts like a temporal low-pass filter, producing:
smooth → stable → consolidated weights
3) Why EMA Improves Performance
(A) Better generalization
EMA weights often lie near the center of flat minima, which tend to generalize better than sharp minima.
(B) Training stability
EMA is particularly helpful for unstable training regimes:
- diffusion models
- GANs
- large transformers
- multimodal and video models
- world models
(C) Better evaluation & inference
Common pattern:
- Train using raw model weights
- Evaluate / infer using EMA weights
In diffusion models, the EMA model is often the only model used at inference time.
4) How EMA Is Implemented in PyTorch
Step 1: Create an EMA copy of the model
```python
import copy

# Shadow copy of the model; frozen so it never receives gradients
ema_model = copy.deepcopy(model)
for p in ema_model.parameters():
    p.requires_grad_(False)
```
This creates a shadow model that does not receive gradients.
Step 2: Apply the EMA update rule
```python
import torch

@torch.no_grad()
def ema_update(ema_model, model, decay=0.999):
    # theta_EMA <- decay * theta_EMA + (1 - decay) * theta_current
    for ema_p, p in zip(ema_model.parameters(), model.parameters()):
        ema_p.mul_(decay).add_(p, alpha=1 - decay)
```
This implements:

$$\theta_{\text{EMA}} \leftarrow \alpha \cdot \theta_{\text{EMA}} + (1 - \alpha) \cdot \theta_{\text{current}}$$
Step 3: Integrate into the training loop
```python
for batch in dataloader:
    loss = compute_loss(model, batch)   # compute_loss is a placeholder for your task's loss
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    ema_update(ema_model, model)        # EMA update after the optimizer step
```
EMA updates are performed after each optimizer step.
Step 4: Use EMA for evaluation / inference
```python
ema_model.eval()
with torch.no_grad():
    predictions = ema_model(batch)   # validate / sample with the EMA weights
```
Use ema_model, not the raw model, when measuring validation metrics or running inference.
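If you'd rather not hand-roll this, the ema-pytorch library linked at the top packages the same pattern. A rough usage sketch, based on my reading of that repo's README (argument names such as beta, update_after_step, and update_every are assumptions; check the README for the exact API):

```python
import torch
from ema_pytorch import EMA  # https://github.com/lucidrains/ema-pytorch

net = torch.nn.Linear(512, 512)   # toy model for illustration

ema = EMA(
    net,
    beta=0.9999,             # EMA decay (assumed argument name)
    update_after_step=100,   # start EMA updates only after this many steps (assumed)
    update_every=10,         # update the shadow copy every N calls (assumed)
)

# inside the training loop, after optimizer.step():
ema.update()

# inference / evaluation with the smoothed weights
out = ema(torch.randn(1, 512))
```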
5) Why EMA Instead of Just a Smaller Learning Rate?
- Small learning rate reduces motion (slower learning, less exploration).
- EMA smooths motion without slowing down learning.
EMA lets you:
- train fast (larger LR, better exploration)
- evaluate stably (smoothed parameters)
So you get speed + stability.
6) Where EMA Is Essentially Mandatory
EMA is common or “expected” in:
- Diffusion models (DDPM, Stable Diffusion, EDM, video diffusion)
- GANs (e.g. StyleGAN uses EMA for the generator)
- Large transformers and foundation models
- Multimodal and long-horizon models
7) EMA vs SWA (Stochastic Weight Averaging)
| Method | Averaging style | Key property |
|---|---|---|
| EMA | Exponential | Recent weights matter more |
| SWA | Uniform | All weights equally weighted |
EMA adapts faster to learning dynamics (because newer weights have higher influence).
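For comparison, recent PyTorch versions expose both styles through torch.optim.swa_utils. The sketch below assumes AveragedModel and get_ema_multi_avg_fn are available (the latter only in newer releases); the toy model and dummy objective are illustrative only:

```python
import torch
from torch.optim.swa_utils import AveragedModel, get_ema_multi_avg_fn

model = torch.nn.Linear(16, 16)                      # toy model
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

swa_model = AveragedModel(model)                                            # uniform average (SWA)
ema_model = AveragedModel(model, multi_avg_fn=get_ema_multi_avg_fn(0.999))  # exponential average (EMA)

for _ in range(100):
    loss = model(torch.randn(8, 16)).pow(2).mean()   # dummy objective
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    swa_model.update_parameters(model)               # every past weight counts equally
    ema_model.update_parameters(model)               # recent weights count more
```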
8) Choosing the Decay Hyperparameter
Typical values:
| Model type | Decay |
|---|---|
| GANs | 0.999 |
| Diffusion | 0.9999 |
| Transformers | 0.999 |
| Video models | 0.9999 |
Higher decay → slower update → smoother but less reactive EMA.
A useful intuition is the “effective window length”:

$$N_{\text{eff}} \approx \frac{1}{1 - \alpha}$$

So:
- $$\alpha = 0.999$$ → window ≈ 1,000 steps
- $$\alpha = 0.9999$$ → window ≈ 10,000 steps
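A quick sanity check of those numbers, using the $$N_{\text{eff}} \approx 1/(1-\alpha)$$ approximation:

```python
def effective_window(decay: float) -> float:
    # N_eff ≈ 1 / (1 - decay): roughly how many recent steps dominate the average
    return 1.0 / (1.0 - decay)

print(effective_window(0.999))    # ≈ 1,000 steps
print(effective_window(0.9999))   # ≈ 10,000 steps
```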
9) Deeper (Research-Level) Intuition
EMA can be interpreted as:
- parameter trajectory averaging
- variance reduction of gradient noise
- an implicit regularizer
- a bias toward flatter minima
It’s not just cosmetic smoothing — it can materially improve the geometry of the solution you end up evaluating.
10) Is EMA basically 2 copies of the weights? What is the actual overhead?
Yes. Literally yes.
You have:
- Raw model weights → optimized by the optimizer
- EMA model weights → a shadow copy updated by EMA
So memory-wise:
EMA ≈ 2× model parameters
No tricks, no compression.
(A) Memory overhead
- Parameters are duplicated
- No optimizer state for EMA
- No gradients
- No activations
So overhead is roughly:
| Component | Raw model | EMA model |
|---|---|---|
| Parameters | ✅ | ✅ |
| Gradients | ✅ | ❌ |
| Optimizer states | ✅ | ❌ |
| Activations | runtime only | ❌ |
Rule of thumb:
EMA adds ~+100% parameter memory, but much less than +100% total training memory.
For large models:
- optimizer states (Adam = 2–3× params)
- activations dominate memory anyway
So EMA is usually “cheap enough”.
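A rough back-of-the-envelope check (assuming fp32 weights and a plain Adam optimizer; activations are ignored here, even though they often dominate in practice):

```python
import torch

model = torch.nn.Linear(4096, 4096)            # toy model for illustration
n_params = sum(p.numel() for p in model.parameters())
bytes_fp32 = 4

weights = n_params * bytes_fp32                # raw parameters
grads   = n_params * bytes_fp32                # gradients
adam    = 2 * n_params * bytes_fp32            # Adam exp_avg + exp_avg_sq
ema     = n_params * bytes_fp32                # EMA shadow copy

print(f"EMA vs. parameters alone:       {ema / weights:.0%}")                   # +100%
print(f"EMA vs. params+grads+optimizer: {ema / (weights + grads + adam):.0%}")  # ~25%
```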
(B) Compute overhead
EMA update cost per step:
```python
decay = 0.999
for p, ema_p in zip(model.parameters(), ema_model.parameters()):
    ema_p.mul_(decay).add_(p, alpha=1 - decay)   # one multiply + one add per parameter
```
That’s:
- one multiply
- one add
- per parameter
Compared to backprop:
EMA cost ≈ negligible (<1%)
11) Usage Cheatsheet
| Stage / Purpose | Model Used | Why | Notes |
|---|---|---|---|
| Training (forward / backward) | Raw model | Only raw weights are optimized via backprop | EMA must never enter backprop |
| Optimizer step | Raw model | Updates trainable parameters | EMA has no optimizer state |
| EMA update | EMA model | Running average of raw weights | After optimizer.step() |
| Training loss logging | Raw model | Reflects actual optimization signal | EMA loss may lag |
| Training accuracy (optional) | EMA model | Smoother, less noisy curves | Log as auxiliary metric |
| Visualization / sampling | EMA model | Cleaner, more stable outputs | Standard in diffusion / GANs |
| Validation | EMA model | Better generalization estimate | Use EMA by default |
| Model selection (“best”) | EMA model | Less variance, more reliable | Save best EMA checkpoint |
| Final evaluation / test | EMA model | Represents consolidated model | Raw model often discarded |
| Deployment / inference | EMA model | Higher quality, more stable | Especially critical for generative models |
Checkpointing with EMA
| What you want | What to save | Can resume training? |
|---|---|---|
| Continue training | Raw + EMA + optimizer | ✅ Yes |
| Only inference / release | EMA only | ❌ No |
| Best validation model | EMA only | ❌ No (but fine for eval) |
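A minimal checkpointing sketch matching the table above (it reuses the model, ema_model, and optimizer objects from section 4; the filenames are arbitrary):

```python
import torch

# Resumable checkpoint: raw weights + EMA weights + optimizer state
torch.save(
    {
        "model": model.state_dict(),
        "ema_model": ema_model.state_dict(),
        "optimizer": optimizer.state_dict(),
    },
    "checkpoint_resume.pt",
)

# Inference-only / release checkpoint: EMA weights alone (training cannot resume from this)
torch.save(ema_model.state_dict(), "checkpoint_ema_only.pt")
```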