https://github.com/lucidrains/ema-pytorch

Exponential Moving Average (EMA) in PyTorch

TL;DR

EMA (Exponential Moving Average) maintains a smoothed version of model weights:

$$\theta_{\text{EMA}} \leftarrow \alpha \cdot \theta_{\text{EMA}} + (1 - \alpha) \cdot \theta_{\text{current}}$$

It stabilizes training and usually gives better validation/inference results, especially for generative models (diffusion, GANs) and large models.

The hard rule

  • EMA parameters must never receive gradients and must never be part of the optimizer.

  • EMA is a non-trainable, read-only copy of the weights. It is, however, a dynamically updated copy, not a static snapshot.

      ┌────────────┐
      │    loss    │
      └─────┬──────┘
            ↓
      ┌────────────┐
      │ raw model  │  ← gradients flow ONLY here
      └─────┬──────┘
            ↓
      optimizer.step()
    
            ↓ (after step, no grad)
      ┌────────────┐
      │  EMA model │  ← read-only
      └────────────┘
    

1) What EMA Is (Conceptually)

EMA maintains a smoothed copy of model parameters over training time.

Mathematically:

$$\theta_{\text{EMA}} \leftarrow \alpha \cdot \theta_{\text{EMA}} + (1 - \alpha) \cdot \theta_{\text{current}}$$

Where:

  • $\theta_{\text{EMA}}$: the smoothed (EMA) parameters
  • $\theta_{\text{current}}$: current model parameters
  • $\alpha$: decay factor (typically close to 1, e.g. 0.999 or 0.9999)
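A quick worked example of the update rule, in plain Python with a deliberately small $\alpha = 0.9$ so the effect is visible after a few steps (the numbers are illustrative, not from a real model):

```python
# Worked example of: theta_ema <- alpha * theta_ema + (1 - alpha) * theta
alpha = 0.9
theta_ema = 0.0                      # EMA starts at the initial weight value
raw_weights = [1.0, 1.0, 1.0, 1.0]   # pretend the raw weight jumped to 1.0

for theta in raw_weights:
    theta_ema = alpha * theta_ema + (1 - alpha) * theta

print(round(theta_ema, 4))  # 0.3439: the EMA lags well behind 1.0
```

Each step closes only a fraction $(1 - \alpha)$ of the gap to the current weight, which is exactly the lag that filters out step-to-step noise.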

2) Intuition: Why EMA Helps

Stochastic optimizers (SGD, Adam, AdamW) produce noisy parameter updates due to:

  • minibatch randomness
  • non-convex loss landscapes
  • oscillations around minima
  • aggressive learning rates

Parameter evolution often looks like:

bounce → overshoot → oscillate → drift → settle → oscillate

EMA acts like a temporal low-pass filter, producing:

smooth → stable → consolidated weights
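The low-pass-filter claim is easy to check numerically. This plain-Python sketch (with an invented noisy "parameter" oscillating around 1.0) compares the variance of the raw track and the EMA track:

```python
import random

random.seed(0)
alpha = 0.99

raw, ema = [], []
theta_ema = 0.0
for step in range(2000):
    theta = 1.0 + random.gauss(0, 0.5)   # noisy "parameter" around 1.0
    theta_ema = alpha * theta_ema + (1 - alpha) * theta
    raw.append(theta)
    ema.append(theta_ema)

def variance(xs):
    m = sum(xs) / len(xs)
    return sum((x - m) ** 2 for x in xs) / len(xs)

# After warm-up, the EMA track is far less noisy than the raw track.
print(variance(raw[1000:]), variance(ema[1000:]))
```

The EMA variance comes out orders of magnitude below the raw variance, which is the "temporal low-pass filter" effect in numbers.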


3) Why EMA Improves Performance

(A) Better generalization

EMA weights often lie near the center of flat minima, which tend to generalize better than sharp minima.

(B) Training stability

EMA is particularly helpful for unstable training regimes:

  • diffusion models
  • GANs
  • large transformers
  • multimodal and video models
  • world models

(C) Better evaluation & inference

Common pattern:

  • Train using raw model weights
  • Evaluate / infer using EMA weights

In diffusion models, the EMA model is often the only model used at inference time.


4) How EMA Is Implemented in PyTorch

Step 1: Create an EMA copy of the model

import copy

ema_model = copy.deepcopy(model)
for p in ema_model.parameters():
    p.requires_grad_(False)

This creates a shadow model that does not receive gradients.
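One subtlety: `deepcopy` also copies buffers (e.g. BatchNorm running statistics), but a parameters-only EMA loop never touches them afterwards. A common choice is to copy buffers over on each update rather than average them; here is a sketch of that variant (the toy `nn.Sequential` model is only for illustration):

```python
import copy

import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 4), nn.BatchNorm1d(4))
ema_model = copy.deepcopy(model)
for p in ema_model.parameters():
    p.requires_grad_(False)

@torch.no_grad()
def ema_update_with_buffers(model, ema_model, decay: float):
    # EMA over learnable parameters
    for p, ema_p in zip(model.parameters(), ema_model.parameters()):
        ema_p.mul_(decay).add_(p, alpha=1 - decay)
    # Buffers (running mean/var, num_batches_tracked) are statistics,
    # not optimized weights; copying them over is a common choice.
    for b, ema_b in zip(model.buffers(), ema_model.buffers()):
        ema_b.copy_(b)

# Simulate one "optimizer step" followed by one EMA update
before = ema_model[0].weight.clone()
with torch.no_grad():
    model[0].weight.add_(1.0)
ema_update_with_buffers(model, ema_model, decay=0.9)
```

Whether to copy or also average buffers is a design choice; copying keeps the EMA model's normalization statistics in sync with the raw model.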


Step 2: Apply the EMA update rule

import torch

@torch.no_grad()
def ema_update(model, ema_model, decay: float):
    for param, ema_param in zip(model.parameters(), ema_model.parameters()):
        ema_param.mul_(decay).add_(param, alpha=1 - decay)

This implements:

$$\theta_{\text{EMA}} \leftarrow \alpha \cdot \theta_{\text{EMA}} + (1 - \alpha) \cdot \theta$$
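The same update can be written as a single linear interpolation using `Tensor.lerp_`, which is mathematically identical (a sketch; the toy `nn.Linear` model is only for demonstration):

```python
import copy

import torch

@torch.no_grad()
def ema_update_lerp(model, ema_model, decay: float):
    # lerp_(end, w) computes: self = self + w * (end - self),
    # which with w = 1 - decay equals decay * self + (1 - decay) * end
    for param, ema_param in zip(model.parameters(), ema_model.parameters()):
        ema_param.lerp_(param, 1 - decay)

# Demo: identical result to the mul_/add_ form above
model = torch.nn.Linear(3, 3)
ema_model = copy.deepcopy(model)
with torch.no_grad():
    model.weight.add_(1.0)          # pretend an optimizer step moved the weights
before = ema_model.weight.clone()
ema_update_lerp(model, ema_model, decay=0.9)
```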


Step 3: Integrate into the training loop

for batch in dataloader:
    loss = compute_loss(model, batch)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    ema_update(model, ema_model, decay=0.999)

EMA updates are performed after each optimizer step.

Step 4: Use EMA for evaluation / inference

ema_model.eval()
with torch.no_grad():
    output = ema_model(x)

Use ema_model, not the raw model, when measuring validation metrics or running inference.

5) Why EMA Instead of Just a Smaller Learning Rate?

  • Small learning rate reduces motion (slower learning, less exploration).
  • EMA smooths motion without slowing down learning.

EMA lets you:

  • train fast (larger LR, better exploration)
  • evaluate stably (smoothed parameters)

So you get speed + stability.
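This trade-off can be seen on a toy problem: aggressive SGD on a 1-D quadratic with noisy gradients oscillates around the optimum, while the EMA of its trajectory sits much closer to it. A plain-Python sketch (all numbers invented for illustration):

```python
import random

random.seed(1)
lr, alpha = 0.3, 0.99               # large LR on purpose
w, w_ema = 0.0, 0.0                 # minimize f(w) = (w - 3)^2, optimum at 3

raw_err, ema_err = [], []
for step in range(3000):
    grad = 2 * (w - 3.0) + random.gauss(0, 1.0)   # noisy gradient
    w -= lr * grad
    w_ema = alpha * w_ema + (1 - alpha) * w
    if step >= 2000:                # measure after both have converged
        raw_err.append(abs(w - 3.0))
        ema_err.append(abs(w_ema - 3.0))

# Mean distance from the optimum: EMA is roughly an order of magnitude closer.
print(sum(raw_err) / len(raw_err), sum(ema_err) / len(ema_err))
```

The raw iterate keeps the large learning rate's exploration; the EMA iterate is the one you would evaluate.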


6) Where EMA Is Essentially Mandatory

EMA is common or “expected” in:

  • Diffusion models (DDPM, Stable Diffusion, EDM, video diffusion)
  • GANs (e.g. StyleGAN uses EMA for the generator)
  • Large transformers and foundation models
  • Multimodal and long-horizon models

7) EMA vs SWA (Stochastic Weight Averaging)

Method | Averaging style | Key property
EMA    | Exponential     | Recent weights matter more
SWA    | Uniform         | All weights equally weighted

EMA adapts faster to learning dynamics (because newer weights have higher influence).
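PyTorch ships a helper that covers both styles: `torch.optim.swa_utils.AveragedModel` does uniform (SWA) averaging by default and accepts a custom `avg_fn` for exponential averaging. A sketch (the `decay=0.999` value and the toy `nn.Linear` model are assumed choices; check the helper against your PyTorch version):

```python
import torch
from torch.optim.swa_utils import AveragedModel

model = torch.nn.Linear(4, 4)

# SWA-style uniform average: the default behavior of AveragedModel
swa_model = AveragedModel(model)

# EMA-style average via a custom avg_fn
def ema_avg(averaged_p, current_p, num_averaged, decay=0.999):
    return decay * averaged_p + (1 - decay) * current_p

ema_model = AveragedModel(model, avg_fn=ema_avg)

# After each optimizer.step() in the training loop:
swa_model.update_parameters(model)   # first call copies the weights
ema_model.update_parameters(model)   # subsequent calls apply avg_fn
```

`AveragedModel` wraps a deep copy of the model (accessible as `.module`) and keeps an `n_averaged` counter, so it also handles the "first update copies, later updates average" bookkeeping for you.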


8) Choosing the Decay Hyperparameter

Typical values:

Model type   | Decay
GANs         | 0.999
Diffusion    | 0.9999
Transformers | 0.999
Video models | 0.9999

Higher decay → slower update → smoother but less reactive EMA.

A useful intuition is the “effective window length”:

$$\text{effective window} \approx \frac{1}{1 - \alpha}$$

So:

  • $\alpha = 0.999 \Rightarrow$ window $\approx 1000$ steps
  • $\alpha = 0.9999 \Rightarrow$ window $\approx 10{,}000$ steps
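The window formula is one line of code, and it pairs naturally with a decay warm-up: early in training the weights improve quickly, so many implementations keep the decay small at first and anneal it toward its final value. The `(1 + step) / (10 + step)` schedule below is one popular choice, not the only one:

```python
def effective_window(alpha: float) -> float:
    # Rough number of recent steps that dominate the EMA average
    return 1.0 / (1.0 - alpha)

def warmup_decay(step: int, max_decay: float = 0.9999) -> float:
    # Keep decay small early so the EMA tracks the fast-improving weights,
    # then anneal toward max_decay as training stabilizes.
    return min(max_decay, (1 + step) / (10 + step))

print(effective_window(0.999))   # ~1000 steps
print(warmup_decay(0))           # 0.1: EMA follows the raw weights closely
print(warmup_decay(100_000))     # 0.9999: full smoothing late in training
```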

9) Deeper (Research-Level) Intuition

EMA can be interpreted as:

  • parameter trajectory averaging
  • variance reduction of gradient noise
  • an implicit regularizer
  • a bias toward flatter minima

It’s not just cosmetic smoothing — it can materially improve the geometry of the solution you end up evaluating.


10) Is EMA basically 2 copies of the weights? What is the actual overhead?

Yes. Literally yes.

You have:

  • Raw model weights → optimized by the optimizer
  • EMA model weights → a shadow copy updated by EMA

So memory-wise:

EMA ≈ 2× model parameters

No tricks, no compression.

(A) Memory overhead

  • Parameters are duplicated
  • No optimizer state for EMA
  • No gradients
  • No activations

So overhead is roughly:

Component        | Raw model    | EMA model
Parameters       | yes          | yes
Gradients        | yes          | no
Optimizer states | yes          | no
Activations      | runtime only | none

Rule of thumb:
EMA adds ~+100% parameter memory, but much less than +100% total training memory.

For large models:

  • optimizer states (Adam = 2–3× params)
  • activations dominate memory anyway

So EMA is usually “cheap enough”.


(B) Compute overhead

EMA update cost per step:

for p, ema_p in zip(...):
    ema_p.mul_(decay).add_(p, alpha=1 - decay)

That’s:

  • one multiply
  • one add
  • per parameter

Compared to backprop:

EMA cost ≈ negligible (<1%)
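If even the Python loop overhead matters (very large parameter counts), the update can be fused across all tensors with the `torch._foreach_*` ops. Note these are underscore-prefixed, semi-private APIs, so verify them against your PyTorch version; a sketch:

```python
import torch

@torch.no_grad()
def ema_update_fused(ema_params, params, decay: float):
    # One fused kernel launch per op across the whole tensor list,
    # instead of one Python-loop iteration per parameter tensor.
    # torch._foreach_* are semi-private APIs; check your PyTorch version.
    torch._foreach_mul_(ema_params, decay)
    torch._foreach_add_(ema_params, params, alpha=1 - decay)

# Toy usage with two "parameter" tensors
params = [torch.ones(4), torch.ones(2)]
ema_params = [torch.zeros(4), torch.zeros(2)]
ema_update_fused(ema_params, params, decay=0.9)
```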

11) Usage Cheatsheet

Stage / Purpose               | Model Used | Why                                          | Notes
Training (forward / backward) | Raw model  | Only raw weights are optimized via backprop  | EMA must never enter backprop
Optimizer step                | Raw model  | Updates trainable parameters                 | EMA has no optimizer state
EMA update                    | EMA model  | Running average of raw weights               | After optimizer.step()
Training loss logging         | Raw model  | Reflects actual optimization signal          | EMA loss may lag
Training accuracy (optional)  | EMA model  | Smoother, less noisy curves                  | Log as auxiliary metric
Visualization / sampling      | EMA model  | Cleaner, more stable outputs                 | Standard in diffusion / GANs
Validation                    | EMA model  | Better generalization estimate               | Use EMA by default
Model selection (“best”)      | EMA model  | Less variance, more reliable                 | Save best EMA checkpoint
Final evaluation / test       | EMA model  | Represents consolidated model                | Raw model often discarded
Deployment / inference        | EMA model  | Higher quality, more stable                  | Especially critical for generative models

Checkpointing with EMA

What you want            | What to save          | Can resume training?
Continue training        | Raw + EMA + optimizer | ✅ Yes
Only inference / release | EMA only              | ❌ No
Best validation model    | EMA only              | ❌ No (but fine for eval)
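The "continue training" row can be sketched as a pair of helpers; the names (`model`, `ema_model`, `optimizer`, `step`) follow the earlier steps, and the tiny `nn.Linear` round-trip is only a demonstration:

```python
import copy
import tempfile

import torch

def save_checkpoint(path, model, ema_model, optimizer, step):
    # Raw + EMA + optimizer: everything needed to resume training
    torch.save({
        "model": model.state_dict(),
        "ema_model": ema_model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "step": step,
    }, path)

def load_checkpoint(path, model, ema_model, optimizer):
    ckpt = torch.load(path, map_location="cpu")
    model.load_state_dict(ckpt["model"])
    ema_model.load_state_dict(ckpt["ema_model"])
    optimizer.load_state_dict(ckpt["optimizer"])
    return ckpt["step"]

# Round-trip a tiny model
model = torch.nn.Linear(2, 2)
ema_model = copy.deepcopy(model)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

path = tempfile.NamedTemporaryFile(suffix=".pt", delete=False).name
save_checkpoint(path, model, ema_model, optimizer, step=7)

with torch.no_grad():
    model.weight.add_(1.0)          # pretend training continued

resumed_step = load_checkpoint(path, model, ema_model, optimizer)
```

For an inference-only release, saving `ema_model.state_dict()` alone is enough, at the cost of not being able to resume training.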