Update: A simplified DDPM PyTorch implementation is available at https://github.com/vinesmsuic/simple-DDPM-annotated.

What is Diffusion?

The essential idea is to systematically and slowly destroy structure in a data distribution through an iterative forward diffusion process which is fixed.

We then learn a reverse diffusion process that restores structure in data, yielding a highly flexible and tractable generative model of the data.

  • Does it look like an autoencoder approach?

DALLE2, MidJourney, Disco Diffusion, Stable Diffusion, and Imagen are built on top of diffusion models.

DALLE1 is an auto-regressive model.

In terms of photo-realism of outputs, diffusion models are better than GANs.

  • Note: It is better if you understand VAE before studying Diffusion.

DDPM

Paper: Denoising Diffusion Probabilistic Models

The Annotated Diffusion Model

Lil’Log: What are Diffusion Models?

YouTube: Diffusion Models | Paper Explanation | Math Explained

YouTube: Diffusion models from scratch in PyTorch

YouTube: CVPR 2022 Tutorial on Denoising Diffusion-based Generative Modeling: Foundations and Applications

A (denoising) diffusion model is a neural network that learns to gradually denoise data starting from pure noise.

The set-up consists of 2 processes:

  • Forward (fixed)
    • Forward diffusion process q: gradually (regulated by a schedule) adds noise (sampled from a normal distribution) to an image until it becomes pure noise
    • A different amount of noise is applied at each timestep according to the schedule (different mean and variance)
  • Reverse (has to be learned)
    • A learned reverse denoising diffusion process p_{\theta}: a neural network is trained to gradually denoise (remove one step of noise in each pass) an image, going from pure noise to an actual image.

By doing so, we can start with completely random noise and let the model remove noise until we have a new image.

It’s a Markov chain because it’s a sequence of stochastic events where each timestep depends only on the previous timestep.

  • Note the latent states have the same dimensionality as the input image.

Training algorithm of DDPM

  • 1: x_0 \sim q(x_0): take a random sample x_0 from the real, unknown, and possibly complex data distribution q(x_0)
  • 2: t \sim \text{Uniform}(\{1,\dots,T\}): sample a noise level t uniformly between 1 and T (i.e., a random time step)
  • 3: \epsilon \sim \mathcal{N}(0, I): sample some noise (with the same dimensionality as the input data) from a Gaussian distribution and corrupt the input by this noise at timestep t, using the reparameterization trick: \sqrt{\bar\alpha_t}\, x_0 + \sqrt{1-\bar\alpha_t}\,\epsilon
  • 4: feed the corrupted sample to the neural network, which predicts the noise from the corrupted image x_t, and train the network
  • 5: steps 1-4 are done on batches of data to optimize the network (a minimal code sketch follows the figure below).
img
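
To make the training algorithm concrete, here is a minimal training-loop sketch (assuming a noise-predicting model, a dataloader of images scaled to [-1, 1], a device, and the p_losses helper defined later in this post; note the timestep is 0-indexed in code):

```python
import torch

# A minimal sketch of steps 1-5 above, not the full training script from the paper.
# Assumes: `model` is a noise-predicting network, `dataloader` yields image batches
# scaled to [-1, 1], and `p_losses` / `T` are defined as later in this post.
optimizer = torch.optim.Adam(model.parameters(), lr=2e-4)

for epoch in range(num_epochs):
    for batch in dataloader:
        x_0 = batch.to(device)                                   # 1: x_0 ~ q(x_0)
        t = torch.randint(0, T, (x_0.shape[0],), device=device)  # 2: random timestep (0-indexed in code)
        loss = p_losses(model, x_0, t, loss_type="l2")           # 3-4: corrupt x_0 and predict the noise
        optimizer.zero_grad()
        loss.backward()                                          # 5: optimize on batches of data
        optimizer.step()
```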

Something Important:

  • The paper used T = 1000 timesteps, but follow-up papers are able to decrease this number.
  • Images are scaled to [-1, 1] so that they have the same range as the prior p(x_T) \sim \mathcal{N}(0, 1), a standard normal distribution centered at 0 with variance 1 (see the small helper sketch below).
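
A minimal sketch of that scaling (assuming inputs are float image tensors in [0, 1]; the helper names here are mine, not from the DDPM code):

```python
def scale_to_minus_one_to_one(img):
    # [0, 1] -> [-1, 1], matching the range of the standard normal prior
    return img * 2.0 - 1.0

def unscale_to_zero_to_one(img):
    # [-1, 1] -> [0, 1], e.g. for visualizing samples
    return (img + 1.0) / 2.0
```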

Inference algorithm of DDPM (Denoising)

As mentioned, generating new images from a diffusion model happens by reversing the diffusion process: we start from T, where we sample pure noise from a Gaussian distribution, and then use our neural network to gradually denoise it (using the conditional probability it has learned), until we end up at time step t = 0. By predicting the noise at each denoising step, we can derive the mean (and variance) of the less noisy image x_{t-1}.

Ideally, we end up with an image that looks like it came from the real data distribution.

  • 1: x_T \sim \mathcal{N}(0, I): draw a sample of pure noise from the normal distribution \mathcal{N}(0, I)

  • 2: for t = T, \dots, 1: for reverse timesteps T down to 1, at every step:

  • 3: draw white noise z from a normal distribution

    • If t = 1, we do not add noise anymore (z = 0).
  • 4: form the new sample from the mean of the denoising model, \frac{1}{\sqrt{\alpha_t}}\left(x_t - \frac{1-\alpha_t}{\sqrt{1-\bar\alpha_t}}\,\epsilon_\theta(x_t, t)\right), and then add the white noise z rescaled by the standard deviation \sigma_t.

img

DDPM From Math Perspective

Both the forward and reverse processes are indexed by t and happen over some finite number of time steps T.

  • Start at t = 0: sample a real image x_0 from the image distribution, and at each time step t add some Gaussian noise to the image from the previous time step t-1.
    • x_1 is the result of the 1st noise iteration, x_{42} of the 42nd, and so on. The last step is x_T.
  • Given a large enough T and a well-behaved schedule for adding noise at each time step, we end up with an isotropic Gaussian distribution at t = T via this gradual process. (Isotropic means it looks the same in every direction.)

img

We define the forward diffusion process q(x_t \mid x_{t-1}): given an image with less noise at timestep t-1, we derive the image with a little bit more noise at timestep t.

We also define the reverse denoising diffusion process p_\theta(x_{t-1} \mid x_t): given an image with more noise at timestep t, we derive the image with less noise at timestep t-1. This is done by predicting the noise that was added to the image.

Forward Diffusion Process

Going from the input image x_0 to the sequence of increasingly noisy versions x_1, \dots, x_T can be formulated as:

q(x_{1:T} \mid x_0) = \prod^T_{t=1} q(x_t \mid x_{t-1})

A single forward diffusion step q(x_t \mid x_{t-1}) can be formulated as a normal distribution \mathcal{N}(z; \mu, \sigma^2):

q(x_t \mid x_{t-1}) = \mathcal{N}\left(x_t;\ \sqrt{1-\beta_t}\, x_{t-1},\ \beta_t I\right)

where

  • \mathcal{N} is the normal distribution
  • x_t is the output
  • \sqrt{1-\beta_t}\, x_{t-1} is the mean
  • \beta_t I is the variance
  • \beta_t is the noise scale given by the schedule

This means the sample x_t is obtained by scaling the previous sample x_{t-1} by \sqrt{1-\beta_t} according to a variance schedule, and then adding independent and identically distributed Gaussian noise with standard deviation \sqrt{\beta_t} at timestep t:

x_t = \sqrt{1-\beta_t}\, x_{t-1} + \sqrt{\beta_t}\, \epsilon

where \epsilon \sim \mathcal{N}(0, I) is a sample of Gaussian noise.
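
As a quick illustration, a single forward step can be sketched as below (illustrative only; in practice we use the closed-form expression derived later instead of stepping one timestep at a time):

```python
import torch

def q_sample_single_step(x_prev, beta_t):
    # x_t = sqrt(1 - beta_t) * x_{t-1} + sqrt(beta_t) * eps, with eps ~ N(0, I)
    eps = torch.randn_like(x_prev)
    return (1.0 - beta_t) ** 0.5 * x_prev + beta_t ** 0.5 * eps
```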

DDPM used a linear schedule, which looks like this:

```python
import torch

def linear_beta_schedule(timesteps):
    beta_start = 0.0001
    beta_end = 0.02
    return torch.linspace(beta_start, beta_end, timesteps)
```

img

  • In later papers, a cosine schedule is used to replace the linear schedule (a sketch is given after the figure below).
    • The linear approach is sub-optimal: the last couple of timesteps already look like complete noise and might be redundant, and at the same time the information is destroyed too fast.
    • The cosine schedule solves both problems of the linear schedule.

img
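
For reference, a cosine schedule in the style of the Improved DDPM paper can be sketched as follows (this follows the formulation used in The Annotated Diffusion Model; the offset s = 0.008 and the clipping bounds are the commonly used values):

```python
import math
import torch

def cosine_beta_schedule(timesteps, s=0.008):
    # alpha_bar follows a squared-cosine curve; betas come from ratios of consecutive alpha_bar values
    steps = timesteps + 1
    x = torch.linspace(0, timesteps, steps)
    alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * math.pi * 0.5) ** 2
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
    return torch.clip(betas, 0.0001, 0.9999)
```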

Since the sum of Gaussians is still a Gaussian distribution, we can apply multiple forward steps in a single step. We define:

  • \alpha_t = 1 - \beta_t
  • \bar\alpha_t = \prod^t_{s=1} \alpha_s
  • Therefore at t = 4, \bar\alpha_4 = \alpha_1 \cdot \alpha_2 \cdot \alpha_3 \cdot \alpha_4 (see the one-line code sketch below)
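
In code, \bar\alpha_t is just a running product of the \alpha_t values, e.g. with torch.cumprod (a one-line sketch):

```python
alphas = 1. - linear_beta_schedule(timesteps=1000)
alphas_cumprod = torch.cumprod(alphas, dim=0)  # 0-indexed: alphas_cumprod[3] = alpha_1 * alpha_2 * alpha_3 * alpha_4
```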

Using the reparameterization trick \mathcal{N}(\mu, \sigma^2) \rightarrow \mu + \sigma \cdot \epsilon,

we get q(x_t \mid x_{t-1}) = \mathcal{N}\left(x_t;\ \sqrt{1-\beta_t}\, x_{t-1},\ \beta_t I\right), i.e. x_t = \sqrt{1-\beta_t}\, x_{t-1} + \sqrt{\beta_t}\, \epsilon_{t-1}.

Substituting \alpha_t = 1 - \beta_t into the formula, we get

x_t = \sqrt{1-\beta_t}\, x_{t-1} + \sqrt{\beta_t}\, \epsilon_{t-1} = \sqrt{\alpha_t}\, x_{t-1} + \sqrt{1-\alpha_t}\, \epsilon_{t-1}.

Given x_{t-1} = \sqrt{\alpha_{t-1}}\, x_{t-2} + \sqrt{1-\alpha_{t-1}}\, \epsilon_{t-2},

we can represent x_t as \sqrt{\alpha_t\alpha_{t-1}}\, x_{t-2} + \sqrt{1-\alpha_t\alpha_{t-1}}\, \bar\epsilon_{t-2}, where \bar\epsilon_{t-2} merges the two Gaussians.

https://lilianweng.github.io/posts/2021-07-11-diffusion-models/#forward-diffusion-process

When we merge two zero-mean Gaussians with different variances, \mathcal{N}(0, \sigma^2_1 I) and \mathcal{N}(0, \sigma^2_2 I),

the new distribution is \mathcal{N}(0, (\sigma^2_1+\sigma^2_2) I). With this property we can compute the merged standard deviation:

x_t = \sqrt{\alpha_t}\, x_{t-1} + \sqrt{1-\alpha_t}\, \epsilon_{t-1} = \sqrt{\alpha_t}\left(\sqrt{\alpha_{t-1}}\, x_{t-2} + \sqrt{1-\alpha_{t-1}}\, \epsilon_{t-2}\right) + \sqrt{1-\alpha_t}\, \epsilon_{t-1}

\sqrt{\alpha_t}\left(\sqrt{\alpha_{t-1}}\, x_{t-2} + \sqrt{1-\alpha_{t-1}}\, \epsilon_{t-2}\right) + \sqrt{1-\alpha_t}\, \epsilon_{t-1} = \sqrt{\alpha_t\alpha_{t-1}}\, x_{t-2} + \left(\sqrt{\alpha_t(1-\alpha_{t-1})}\, \epsilon_{t-2} + \sqrt{1-\alpha_t}\, \epsilon_{t-1}\right)

Merging the two noise terms into a new distribution \mathcal{N}(0, (\sigma^2_1+\sigma^2_2) I),

\sqrt{\alpha_t(1-\alpha_{t-1})}\, \epsilon_{t-2} + \sqrt{1-\alpha_t}\, \epsilon_{t-1} = \sqrt{\left(\sqrt{\alpha_t(1-\alpha_{t-1})}\right)^2 + \left(\sqrt{1-\alpha_t}\right)^2}\ \bar\epsilon_{t-2}

= \sqrt{\alpha_t(1-\alpha_{t-1}) + 1 - \alpha_t}\ \bar\epsilon_{t-2} = \sqrt{\cancel{\alpha_t} - \alpha_t\alpha_{t-1} + 1 \cancel{- \alpha_t}}\ \bar\epsilon_{t-2}

= \sqrt{1-\alpha_t\alpha_{t-1}}\ \bar\epsilon_{t-2}

Therefore we can represent x_t as \sqrt{\alpha_t\alpha_{t-1}}\, x_{t-2} + \sqrt{1-\alpha_t\alpha_{t-1}}\, \bar\epsilon_{t-2}, where \bar\epsilon_{t-2} merges the two Gaussians.

With this logic, we can represent x_t as \sqrt{\alpha_t\alpha_{t-1}\alpha_{t-2}}\, x_{t-3} + \sqrt{1-\alpha_t\alpha_{t-1}\alpha_{t-2}}\, \bar\epsilon_{t-3}

Therefore, to sum up,

x_t = \sqrt{\alpha_t \alpha_{t-1} \cdots \alpha_1}\, x_0 + \sqrt{1-\alpha_t \alpha_{t-1} \cdots \alpha_1}\, \bar\epsilon = \sqrt{\bar\alpha_t}\, x_0 + \sqrt{1-\bar\alpha_t}\, \bar\epsilon

and rewrite as:

q(x_t \mid x_0) = \mathcal{N}\left(x_t;\ \sqrt{\bar\alpha_t}\, x_0,\ (1-\bar\alpha_t) I\right)

So to sample x_t directly we can simply do

x_t = \sqrt{\bar\alpha_t}\, x_0 + \sqrt{1-\bar\alpha_t}\, \epsilon

where \epsilon \sim \mathcal{N}(0, I) is a sample of Gaussian noise.
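
As an optional sanity check of this closed form (a sketch, not part of the original implementation), iterating the single-step update gives the same variance as the 1 - \bar\alpha_t predicted by the formula:

```python
import torch

betas = torch.linspace(1e-4, 0.02, 1000)             # linear schedule
alphas_cumprod = torch.cumprod(1. - betas, dim=0)

t = 499
x = torch.zeros(100_000)                              # start from x_0 = 0 so the variance is easy to read off
for s in range(t + 1):                                # x_s = sqrt(1 - beta_s) * x_{s-1} + sqrt(beta_s) * eps
    x = torch.sqrt(1. - betas[s]) * x + torch.sqrt(betas[s]) * torch.randn_like(x)

print(x.var().item(), (1. - alphas_cumprod[t]).item())  # both are close to 1 - alpha_bar_t
```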

```python
import torch
import torch.nn.functional as F

# Define beta schedule
T = 250
betas = linear_beta_schedule(timesteps=T)

# calculate alphas
alphas = 1. - betas
alphas_cumprod = torch.cumprod(alphas, axis=0)
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
sqrt_recip_alphas = torch.sqrt(1.0 / alphas)

# calculations for diffusion q(x_t | x_{t-1}) and others
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod)

# calculations for posterior q(x_{t-1} | x_t, x_0)
posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
```
```python
# Get the indexed term from the list
# https://github.com/pytorch/pytorch/issues/15245 Gather backward is faster than integer indexing on GPU
def extract(a, t, x_shape):
    batch_size = t.shape[0]
    out = a.gather(-1, t.cpu())
    return out.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(t.device)

# forward diffusion (using the nice property alpha)
def q_sample(x_start, t, noise=None):
    if noise is None:
        noise = torch.randn_like(x_start)

    # \sqrt{\bar\alpha_t}
    sqrt_alphas_cumprod_t = extract(sqrt_alphas_cumprod, t, x_start.shape)

    # \sqrt{(1-\bar\alpha_t)}
    sqrt_one_minus_alphas_cumprod_t = extract(
        sqrt_one_minus_alphas_cumprod, t, x_start.shape
    )

    # sample from q(x_t | x_0) = N(x_t; sqrt(alpha_bar_t) * x_0, (1 - alpha_bar_t) * I)
    # via the reparameterization trick: x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise
    return sqrt_alphas_cumprod_t * x_start + sqrt_one_minus_alphas_cumprod_t * noise
```

So the forward process looks like this:

img

But what is really happening?

img

  • At small t, most of the low-frequency content is not perturbed by the noise, but the high-frequency content is being perturbed.
  • At larger t, the low-frequency content is also perturbed.
  • At the end of the forward process, we have gotten rid of both the low- and high-frequency content of the image.

Parametrized Reverse Denoising Diffusion Process

To reverse the process, the intuitive idea is to find q(x_{t-1} \mid x_t).

However, q(x_{t-1} \mid x_t) \propto q(x_{t-1})\, q(x_t \mid x_{t-1}) is intractable. Instead, we can approximate q(x_{t-1} \mid x_t) with a normal distribution if \beta_t is small in each forward diffusion step.

Generating new images from a diffusion model happens by reversing the diffusion process: we start from T, where we sample pure noise from a Gaussian distribution, and then use our neural network to gradually denoise it (using the conditional probability it has learned), until we end up at time step t = 0. Ideally, we end up with an image that looks like it came from the real data distribution.

Going from pure noise (or a noisy image) x_t towards the original image x_0 (or a less noisy image), the process can be formulated as:

p_{\theta}(x_{0:T}) = p(x_T)\prod^T_{t=1} p_{\theta}(x_{t-1} \mid x_t)

p(x_T) = \mathcal{N}(x_T; 0, I)

A single reverse denoising diffusion step p_\theta(x_{t-1} \mid x_t) can be formulated as \mathcal{N}(z; \mu, \sigma^2):

p_\theta\left(x_{t-1} \mid x_{t}\right) = \mathcal{N}\left(x_{t-1};\ \mu_{\theta}(x_t, t),\ \Sigma_{\theta}(x_t, t)\right)

where

  • \mathcal{N} is the normal distribution
  • \mu_{\theta} parametrizes the mean
  • \Sigma_{\theta} parametrizes the variance

We can derive a slightly less noisy image x_{t-1} by plugging in the reparameterization of the mean and variance.

We need a trainable neural network to represent a (conditional) probability distribution of the backward process. We want to learn 2 parameters:

  • a mean parametrized by \mu_{\theta}(x_t, t)
  • a variance parametrized by \Sigma_{\theta}(x_t, t)

However, the DDPM authors decided to keep the variance fixed, and let the neural network only learn (represent) the mean \mu_\theta of the conditional probability distribution.

p_\theta\left(x_{t-1} \mid x_{t}\right) = \mathcal{N}\left(x_{t-1};\ \mu_{\theta}(x_t, t),\ \beta_t I\right)

where a linear schedule is used

Later in the Improved diffusion models paper, a neural network also learns the variance of this backward process, besides the mean.

By predicting the noise, the network tells us what was added to the image at this step.

To get the less noisy image, we roughly compute x_{t-1} \approx x_t - \text{noise}; the exact update, with the proper scaling, is given below.

The predicted mean is:

\mu_{\theta}(x_{t}, t) = \tilde{\mu}_{t}\left(x_t,\ \frac{1}{\sqrt{\bar{\alpha}_t}}\left(x_t - \sqrt{1-\bar{\alpha}_t}\,\epsilon_{\theta}(x_t)\right)\right)

= \frac{1}{\sqrt{\alpha_t}}\left(x_t - \frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}}\,\epsilon_{\theta}(x_t, t)\right) = \frac{1}{\sqrt{\alpha_t}}\left(x_t - \frac{\beta_t\, \epsilon_{\theta}(x_t, t)}{\sqrt{1-\bar{\alpha}_t}}\right)

```python
# Define beta schedule
T = 300
betas = linear_beta_schedule(timesteps=T)

# calculate alphas
alphas = 1. - betas
alphas_cumprod = torch.cumprod(alphas, axis=0)
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
sqrt_recip_alphas = torch.sqrt(1.0 / alphas)

# calculations for diffusion q(x_t | x_{t-1}) and others
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod)

# calculations for posterior q(x_{t-1} | x_t, x_0)
posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)

# Get the indexed term from the list
# https://github.com/pytorch/pytorch/issues/15245 Gather backward is faster than integer indexing on GPU
def extract(a, t, x_shape):
    batch_size = t.shape[0]
    out = a.gather(-1, t.cpu())
    return out.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(t.device)
```

```python
@torch.no_grad()
def p_sample(model, x, t, t_index):
    """
    Sample from the model. Mean is predicted, variance is fixed in this example.
    """
    betas_t = extract(betas, t, x.shape)
    sqrt_one_minus_alphas_cumprod_t = extract(
        sqrt_one_minus_alphas_cumprod, t, x.shape
    )
    sqrt_recip_alphas_t = extract(sqrt_recip_alphas, t, x.shape)

    # Use our model (noise predictor) to predict the mean
    model_mean = sqrt_recip_alphas_t * (
        x - betas_t * model(x, t) / sqrt_one_minus_alphas_cumprod_t
    )

    if t_index == 0:
        return model_mean
    else:
        posterior_variance_t = extract(posterior_variance, t, x.shape)
        noise = torch.randn_like(x)
        # x_{t-1} sample is generated
        image = model_mean + torch.sqrt(posterior_variance_t) * noise
        return image
```

To denoise from timestep T down to 0 and get a clean image, we iteratively apply the denoising step.

```python
from tqdm.auto import tqdm

@torch.no_grad()
def p_sample_loop(model, shape):
    device = next(model.parameters()).device

    b = shape[0]
    # start from pure noise (for each example in the batch)
    img = torch.randn(shape, device=device)
    imgs = []

    # T is the total number of timesteps defined with the beta schedule above
    for i in tqdm(reversed(range(0, T)), desc='sampling loop time step', total=T):
        img = p_sample(model, img, torch.full((b,), i, device=device, dtype=torch.long), i)
        imgs.append(img.cpu().numpy())
    return imgs

@torch.no_grad()
def sample(model, image_size, batch_size=16, channels=3):
    return p_sample_loop(model, shape=(batch_size, channels, image_size, image_size))
```

The above code returns the image output at every denoising step. Alternatively, we can modify the code to return only the last step.

```python
@torch.no_grad()
def p_sample_loop_last(model, shape):
    device = next(model.parameters()).device

    b = shape[0]
    # start from pure noise (for each example in the batch)
    img = torch.randn(shape, device=device)

    final_img = None
    for i in tqdm(reversed(range(0, T)), desc='sampling loop time step', total=T):
        img = p_sample(model, img, torch.full((b,), i, device=device, dtype=torch.long), i)
        final_img = img.cpu().numpy()
    return final_img
```

Recall what we understood from the forward process: at large timesteps the low-frequency components are hidden (destroyed), while at small timesteps it is mainly the high-frequency content that is hidden.

Therefore, in the reverse process, we can trade off global content against fine detail through the weighting of the timesteps.

  • Low-frequency content corresponds to the main content of the image
  • High-frequency content corresponds to the low-level fine details

Therefore the noise schedule can play a huge role here.
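
To make this concrete, we can inspect how much of the original signal (the \sqrt{\bar\alpha_t} coefficient on x_0) survives at each timestep under different schedules (a sketch; cosine_beta_schedule refers to the sketch given earlier, not to code from the DDPM paper):

```python
T = 1000
for name, betas in [("linear", linear_beta_schedule(T)), ("cosine", cosine_beta_schedule(T))]:
    signal = torch.sqrt(torch.cumprod(1. - betas, dim=0))  # sqrt(alpha_bar_t): how much of x_0 remains
    print(name, signal[::T // 5])
```

The cosine schedule keeps noticeably more signal at the same t and avoids the long stretch of near-zero signal at the end of the linear schedule.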

Objective function of Diffusion

The loss function is simply the negative log likelihood

-\log(p_\theta(x_0))

However, p_\theta(x_0) is not nicely computable, as it depends on all the other timesteps coming before x_0 (i.e. x_T, \dots, x_1).

As a solution we can compute the variational lower bound, which is commonly used for training Variational Autoencoders. We write this formula:

-\log(p_\theta(x_0)) \leq -\log(p_\theta(x_0)) + D_{KL}(q(x_{1:T} \mid x_0)\,\|\,p_\theta(x_{1:T} \mid x_0))

  • Note the KL divergence D_{KL}(q(x_{1:T} \mid x_0)\,\|\,p_\theta(x_{1:T} \mid x_0)) is non-negative.

But it is still not computable, as the -\log(p_\theta(x_0)) term still exists; so we need to further rewrite the formula:

We can rewrite the KL divergence such that

D_{KL}(q(x_{1:T} \mid x_0)\,\|\,p_\theta(x_{1:T} \mid x_0)) = \log\left(\frac{q(x_{1:T} \mid x_0)}{p_\theta(x_{1:T} \mid x_0)}\right)

Then apply Bayes' rule to the denominator:

p_\theta(x_{1:T} \mid x_0) = \frac{p_\theta(x_0 \mid x_{1:T})\, p_\theta(x_{1:T})}{p_\theta(x_0)}

\frac{p_\theta(x_0 \mid x_{1:T})\, p_\theta(x_{1:T})}{p_\theta(x_0)} = \frac{p_\theta(x_0, x_{1:T})}{p_\theta(x_0)} = \frac{p_\theta(x_{0:T})}{p_\theta(x_0)}

Plug this back into the log and rearrange:

\log\left(\frac{q(x_{1:T} \mid x_0)}{\frac{p_\theta(x_{0:T})}{p_\theta(x_0)}}\right) = \log\left(\frac{q(x_{1:T} \mid x_0)\, p_\theta(x_0)}{p_\theta(x_{0:T})}\right)

\log\left(\frac{q(x_{1:T} \mid x_0)\, p_\theta(x_0)}{p_\theta(x_{0:T})}\right) = \log\left(\frac{q(x_{1:T} \mid x_0)}{p_\theta(x_{0:T})}\right) + \log(p_\theta(x_0))

So we get

-\log(p_\theta(x_0)) \leq -\log(p_\theta(x_0)) + \log\left(\frac{q(x_{1:T} \mid x_0)}{p_\theta(x_{0:T})}\right) + \log(p_\theta(x_0))

in which the two \log(p_\theta(x_0)) terms cancel each other, and we are left with the variational bound that we want to minimize:

-\log(p_\theta(x_0)) \leq \log\left(\frac{q(x_{1:T} \mid x_0)}{p_\theta(x_{0:T})}\right)

where

  • q(x_{1:T} \mid x_0) is the forward process (given the initial image x_0, produce x_{1:T})
    • which can be reformulated as \prod^T_{t=1} q(x_t \mid x_{t-1})
  • p_\theta(x_{0:T}) is the reverse process (from a noisy image back to the initial image)
    • which can be reformulated as p(x_T)\prod^T_{t=1} p_\theta(x_{t-1} \mid x_t)

-\log(p_\theta(x_0)) \leq \log\left(\frac{\prod^T_{t=1} q(x_t \mid x_{t-1})}{p(x_T)\prod^T_{t=1} p_\theta(x_{t-1} \mid x_t)}\right)

We can extract the computable term p(x_T) from the fraction:

\log\left(\frac{\prod^T_{t=1} q(x_t \mid x_{t-1})}{p(x_T)\prod^T_{t=1} p_\theta(x_{t-1} \mid x_t)}\right) = \log\left(\frac{\prod^T_{t=1} q(x_t \mid x_{t-1})}{\prod^T_{t=1} p_\theta(x_{t-1} \mid x_t)}\right) - \log(p(x_T))

Bring the product out as a sum (log rules):

\log\left(\frac{\prod^T_{t=1} q(x_t \mid x_{t-1})}{\prod^T_{t=1} p_\theta(x_{t-1} \mid x_t)}\right) - \log(p(x_T)) = -\log(p(x_T)) + \sum^T_{t=1}\log\left(\frac{q(x_t \mid x_{t-1})}{p_\theta(x_{t-1} \mid x_t)}\right)

The authors use a little trick to move the parametrized term. First, split off the t = 1 term of the summation:

-\log(p(x_T)) + \sum^T_{t=1}\log\left(\frac{q(x_t \mid x_{t-1})}{p_\theta(x_{t-1} \mid x_t)}\right) = -\log(p(x_T)) + \sum^T_{t=2}\log\left(\frac{q(x_t \mid x_{t-1})}{p_\theta(x_{t-1} \mid x_t)}\right) + \log\left(\frac{q(x_1 \mid x_0)}{p_\theta(x_0 \mid x_1)}\right)

By Bayes' rule, q(x_t \mid x_{t-1}) = \frac{q(x_{t-1} \mid x_t)\, q(x_t)}{q(x_{t-1})}. These terms have high variance (we do not know where the noisy image came from). We can reduce the variance by additionally conditioning on x_0 (the initial noise-free picture), such that

q(x_t \mid x_{t-1}) = \frac{q(x_{t-1} \mid x_t)\, q(x_t)}{q(x_{t-1})} \Rightarrow \frac{q(x_{t-1} \mid x_t, x_0)\, q(x_t \mid x_0)}{q(x_{t-1} \mid x_0)}

Plugging q(x_t \mid x_{t-1}) = \frac{q(x_{t-1} \mid x_t, x_0)\, q(x_t \mid x_0)}{q(x_{t-1} \mid x_0)} into the expression above, we have

-\log(p(x_T)) + \sum^T_{t=2}\log\left(\frac{q(x_{t-1} \mid x_t, x_0)\, q(x_t \mid x_0)}{p_\theta(x_{t-1} \mid x_t)\, q(x_{t-1} \mid x_0)}\right) + \log\left(\frac{q(x_1 \mid x_0)}{p_\theta(x_0 \mid x_1)}\right)

Split up the summation part

= -\log(p(x_T)) + \sum^T_{t=2}\log\left(\frac{q(x_{t-1} \mid x_t, x_0)}{p_\theta(x_{t-1} \mid x_t)}\right) + \sum^T_{t=2}\log\left(\frac{q(x_t \mid x_0)}{q(x_{t-1} \mid x_0)}\right) + \log\left(\frac{q(x_1 \mid x_0)}{p_\theta(x_0 \mid x_1)}\right)

The term \sum^T_{t=2}\log\left(\frac{q(x_t \mid x_0)}{q(x_{t-1} \mid x_0)}\right) can be simplified as

\sum^T_{t=2}\log\left(\frac{q(x_t \mid x_0)}{q(x_{t-1} \mid x_0)}\right) = \log\left(\prod^T_{t=2}\frac{q(x_t \mid x_0)}{q(x_{t-1} \mid x_0)}\right) = \log\left(\frac{q(x_T \mid x_0)}{q(x_1 \mid x_0)}\right)

because the product telescopes. For example, with T = 5:

\sum^5_{t=2}\log\left(\frac{q(x_t \mid x_0)}{q(x_{t-1} \mid x_0)}\right) = \log\left(\prod^5_{t=2}\frac{q(x_t \mid x_0)}{q(x_{t-1} \mid x_0)}\right) = \log\left(\frac{\cancel{q(x_2 \mid x_0)\, q(x_3 \mid x_0)\, q(x_4 \mid x_0)}\, q(x_5 \mid x_0)}{q(x_1 \mid x_0)\, \cancel{q(x_2 \mid x_0)\, q(x_3 \mid x_0)\, q(x_4 \mid x_0)}}\right) = \log\left(\frac{q(x_5 \mid x_0)}{q(x_1 \mid x_0)}\right)

So we get

-\log(p(x_T)) + \sum^T_{t=2}\log\left(\frac{q(x_{t-1} \mid x_t, x_0)}{p_\theta(x_{t-1} \mid x_t)}\right) + \log\left(\frac{q(x_T \mid x_0)}{q(x_1 \mid x_0)}\right) + \log\left(\frac{q(x_1 \mid x_0)}{p_\theta(x_0 \mid x_1)}\right)

Expanding the logs (log rules), we find that some terms cancel:

= -\log(p(x_T)) + \sum^T_{t=2}\log\left(\frac{q(x_{t-1} \mid x_t, x_0)}{p_\theta(x_{t-1} \mid x_t)}\right) + \log(q(x_T \mid x_0)) \cancel{- \log(q(x_1 \mid x_0)) + \log(q(x_1 \mid x_0))} - \log(p_\theta(x_0 \mid x_1))

= -\log(p(x_T)) + \sum^T_{t=2}\log\left(\frac{q(x_{t-1} \mid x_t, x_0)}{p_\theta(x_{t-1} \mid x_t)}\right) + \log(q(x_T \mid x_0)) - \log(p_\theta(x_0 \mid x_1))

Fuse the terms -\log(p(x_T)) and +\log(q(x_T \mid x_0)) using the log rules to form \log\left(\frac{q(x_T \mid x_0)}{p(x_T)}\right):

= \log\left(\frac{q(x_T \mid x_0)}{p(x_T)}\right) + \sum^T_{t=2}\log\left(\frac{q(x_{t-1} \mid x_t, x_0)}{p_\theta(x_{t-1} \mid x_t)}\right) - \log(p_\theta(x_0 \mid x_1))

Converting \log\left(\frac{q(x_T \mid x_0)}{p(x_T)}\right) and \sum^T_{t=2}\log\left(\frac{q(x_{t-1} \mid x_t, x_0)}{p_\theta(x_{t-1} \mid x_t)}\right) into KL divergence terms, our objective is:

L = \mathbb{E}_q\Big[\underbrace{D_{\mathrm{KL}}\left(q(x_T \mid x_0)\,\|\, p(x_T)\right)}_{L_T} + \sum_{t>1} \underbrace{D_{\mathrm{KL}}\left(q(x_{t-1} \mid x_t, x_0)\,\|\, p_\theta(x_{t-1} \mid x_t)\right)}_{L_{t-1}} \underbrace{- \log p_\theta(x_0 \mid x_1)}_{L_0}\Big]

D_{KL}(q(x_T \mid x_0)\,\|\, p(x_T)) + \sum^T_{t=2} D_{KL}(q(x_{t-1} \mid x_t, x_0)\,\|\, p_\theta(x_{t-1} \mid x_t)) - \log(p_\theta(x_0 \mid x_1))

Note the first term D_{KL}(q(x_T \mid x_0)\,\|\, p(x_T)) can be dropped because it has no learnable parameters and is small: it is simply the KL divergence between the diffusion kernel at the last step, q(x_T \mid x_0), and the base distribution p(x_T). Since q(x_T \mid x_0) converges to a standard normal distribution, which is the same distribution as p(x_T), the term is negligible. After dropping the first term we have:

\sum^T_{t=2} D_{KL}(q(x_{t-1} \mid x_t, x_0)\,\|\, p_\theta(x_{t-1} \mid x_t)) - \log(p_\theta(x_0 \mid x_1))

Given p_\theta(x_{t-1} \mid x_t) = \mathcal{N}(x_{t-1}; \mu_{\theta}(x_t, t), \beta_t I) (in DDPM the variance is fixed), and

q(x_{t-1} \mid x_t, x_0) = \mathcal{N}(x_{t-1}; \tilde{\mu}_t(x_t, x_0), \tilde{\beta}_t I) (we skip the derivation, but q(x_{t-1} \mid x_t, x_0) is the tractable posterior distribution).

Note

  • \tilde{\mu}_t(x_t, x_0) = \frac{\sqrt{\alpha_t}(1-\bar{\alpha}_{t-1})}{1-\bar{\alpha}_t} x_t + \frac{\sqrt{\bar{\alpha}_{t-1}}\,\beta_t}{1-\bar{\alpha}_t} x_0
  • \tilde{\beta}_t = \frac{1-\bar{\alpha}_{t-1}}{1-\bar{\alpha}_t} \cdot \beta_t (see the short code sketch below)
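
In code, these posterior quantities are just per-timestep coefficients computed from the schedule (a sketch in the spirit of the buffers defined earlier; the mean-coefficient variable names are mine):

```python
# coefficients of x_t and x_0 in the posterior mean, plus the posterior variance
posterior_mean_coef_xt = torch.sqrt(alphas) * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
posterior_mean_coef_x0 = torch.sqrt(alphas_cumprod_prev) * betas / (1. - alphas_cumprod)
posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
```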

Recall: to apply multiple forward steps in one step,

x_t = \sqrt{\bar{\alpha}_t}\, x_0 + \sqrt{1-\bar{\alpha}_t}\, \epsilon

Therefore

x_0 = \frac{1}{\sqrt{\bar{\alpha}_t}}\left(x_t - \sqrt{1-\bar{\alpha}_t}\, \epsilon\right)

Plugging x_0 = \frac{1}{\sqrt{\bar{\alpha}_t}}\left(x_t - \sqrt{1-\bar{\alpha}_t}\, \epsilon\right) into \tilde{\mu}_t(x_t, x_0) = \frac{\sqrt{\alpha_t}(1-\bar{\alpha}_{t-1})}{1-\bar{\alpha}_t} x_t + \frac{\sqrt{\bar{\alpha}_{t-1}}\,\beta_t}{1-\bar{\alpha}_t} x_0, we get:

\tilde{\mu}_t(x_t, x_0) = \frac{\sqrt{\alpha_t}(1-\bar{\alpha}_{t-1})}{1-\bar{\alpha}_t} x_t + \frac{\sqrt{\bar{\alpha}_{t-1}}\,\beta_t}{1-\bar{\alpha}_t} \cdot \frac{1}{\sqrt{\bar{\alpha}_t}}\left(x_t - \sqrt{1-\bar{\alpha}_t}\, \epsilon\right)

and end up with the posterior mean expressed in terms of the noise:

\tilde{\mu}_t(x_t, x_0) = \frac{1}{\sqrt{\alpha_t}}\left(x_t - \frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}}\, \epsilon\right)

Therefore we can predict the noise with the neural network and plug it in:

\mu_{\theta}(x_t, t) = \frac{1}{\sqrt{\alpha_t}}\left(x_t - \frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}}\, \epsilon_{\theta}(x_t, t)\right)

The DDPM authors use an MSE loss between the true and predicted means, which (as shown below) reduces to an MSE between the actual noise and the predicted noise:

\mathcal{L}_{t-1} = \frac{1}{2\sigma^2_t}\left\|\tilde{\mu}_t(x_t, x_0) - \mu_{\theta}(x_t, t)\right\|^2

\mathcal{L}_{t-1} = \frac{1}{2\sigma^2_t}\left\|\frac{1}{\sqrt{\alpha_t}}\left(x_t - \frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}}\,\epsilon\right) - \frac{1}{\sqrt{\alpha_t}}\left(x_t - \frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}}\,\epsilon_{\theta}(x_t, t)\right)\right\|^2

The two terms are completely the same except for the epsilon term. Through simplification, we get

\mathcal{L}_{t-1} = \frac{\beta_t^2}{2\sigma^2_t\, \alpha_t (1-\bar{\alpha}_t)}\left\|\epsilon - \epsilon_\theta(x_t, t)\right\|^2

We can replace the scalar factor with a time-dependent lambda value:

\mathcal{L}_{t-1} = \lambda_t\left\|\epsilon - \epsilon_\theta(x_t, t)\right\|^2

where the time-dependent \lambda_t ensures that the training objective is weighted properly for maximum data likelihood training. However, this weight is often very large for small t's and very small for large t's (see the quick check below).
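
A quick check of this behavior (a sketch, taking \sigma_t^2 = \beta_t, one of the two fixed-variance choices discussed in the DDPM paper):

```python
betas = linear_beta_schedule(timesteps=1000)
alphas = 1. - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)

# lambda_t = beta_t^2 / (2 * sigma_t^2 * alpha_t * (1 - alpha_bar_t)), with sigma_t^2 = beta_t
lambda_t = betas ** 2 / (2. * betas * alphas * (1. - alphas_cumprod))
print(lambda_t[0].item(), lambda_t[-1].item())  # ~0.5 at the smallest t vs ~0.01 at the largest t
```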

… And then the authors found that ignoring the scaling term (setting \lambda_t = 1) results in better sample quality. Therefore we get:

\mathcal{L}_{t-1} = \left\|\epsilon - \epsilon_\theta(x_t, t)\right\|^2

Plugging this back in place of the \sum^T_{t=2} D_{KL}(q(x_{t-1} \mid x_t, x_0)\,\|\, p_\theta(x_{t-1} \mid x_t)) term:

\mathcal{L}_{VLB} = \sum^T_{t=2}\left\|\epsilon - \epsilon_\theta(x_t, t)\right\|^2 - \log(p_\theta(x_0 \mid x_1))

Since at sampling time t = 1 we do not add noise, we can drop the -\log(p_\theta(x_0 \mid x_1)) term as well. So finally:

\mathcal{L}_{simple} = \mathbb{E}_{t, x_0, \epsilon}\left[\left\|\epsilon - \epsilon_\theta(x_t, t)\right\|^2\right]

That is why, in implementations, we basically compute an MSE loss (or a similar regression loss) between the real noise and the predicted noise.

```python
def p_losses(denoise_model, x_start, t, noise=None, loss_type="l1"):
    if noise is None:
        noise = torch.randn_like(x_start)

    x_noisy = q_sample(x_start=x_start, t=t, noise=noise)
    predicted_noise = denoise_model(x_noisy, t)

    if loss_type == 'l1':
        loss = F.l1_loss(noise, predicted_noise)
    elif loss_type == 'l2':
        loss = F.mse_loss(noise, predicted_noise)
    elif loss_type == "huber":
        loss = F.smooth_l1_loss(noise, predicted_noise)
    else:
        raise NotImplementedError()

    return loss
```

Components of a Diffusion model

We will need mainly 3 components:

  • A UNet model that predicts the noise in an image
  • A noise scheduler that sequentially adds noise
  • A way to encode the current timestep

Generally we want to use a network that is similar to an autoencoder: we want a “bottleneck” layer in between the encoder and decoder. The encoder first encodes an image into a smaller hidden representation called the “bottleneck”, and the decoder then decodes that hidden representation back into an actual image. This forces the network to only keep the most important information in the bottleneck layer.

  • The DDPM authors used a U-Net, similar to an unmasked PixelCNN++, with group normalization throughout
    • bottleneck, plus residual connections between encoder and decoder (greatly improving gradient flow)
    • Attention, ConvNext
    • A sinusoidal embedding (from the Transformer) of the timestep is projected into each residual block, so that the denoising can be matched to the noise level of the forward process at that timestep
    • UNet is a segmentation-style network whose output has the same dimensions as its input
  • The model takes a noisy image with 3 color channels as input and predicts the noise in the image
    • That means the model learns the mean (and optionally the variance) of the Gaussian distribution of the images
      • Known as denoising score matching

img

  • Timestep encoding is used to encode the timestep so that the model knows which timestep it is at.
    • To encode the timestep, we can use a sinusoidal embedding (hand-crafted) or other positional embeddings (learned from data)
    • The key idea is that each position has a unique positional vector (a small sketch follows the figure below)

img
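
For completeness, a minimal sinusoidal timestep embedding module (a sketch following the Transformer-style formulation, as also used in The Annotated Diffusion Model; the dimension split and the constant 10000 are the usual convention):

```python
import math
import torch
from torch import nn

class SinusoidalPositionEmbeddings(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        # time: (batch,) tensor of integer timesteps -> (batch, dim) embedding
        device = time.device
        half_dim = self.dim // 2
        embeddings = math.log(10000) / (half_dim - 1)
        embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)
        embeddings = time[:, None] * embeddings[None, :]
        embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)
        return embeddings
```

This embedding is typically passed through a small MLP and then added into each residual block of the UNet, as described above.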