Score-based Generative Modeling with Differential Equations

“Score-Based Generative Modeling through Stochastic Differential Equations”

First, let's revisit the Forward Diffusion Process:

$$q(x_t|x_{t-1})=\mathcal{N}(x_t; \sqrt{1-\beta_t}\,x_{t-1}, \beta_tI)$$

Now consider the limit of many small steps:

$$q(x_t|x_{t-1})=\mathcal{N}(x_t; \sqrt{1-\beta_t}\,x_{t-1}, \beta_tI)$$

Applying the reparametrization trick:

$$x_t = \sqrt{1-\beta_t}\,x_{t-1} + \sqrt{\beta_t}\,\mathcal{N}(0,I)$$

We define the stepsize $\beta_t$ as $\beta(t)\Delta t$:

$$x_t = \sqrt{1-\beta(t)\Delta t}\, x_{t-1}+\sqrt{\beta(t)\Delta t}\,\mathcal{N}(0,I)$$

If there are many, many timesteps, $\Delta t$ goes towards 0 and we can perform a Taylor expansion of $\sqrt{1-\beta(t)\Delta t}$:

$$x_t \approx x_{t-1} - \frac{\beta(t)\Delta t}{2}x_{t-1} + \sqrt{\beta(t)\Delta t}\,\mathcal{N}(0,I)$$

We can interpret this as an iterative update: the new $x_t$ is given by the old $x_{t-1}$, minus a term that depends on $x_{t-1}$ itself, plus some added noise.

This iterative update corresponds to a particular discretization of a Stochastic Differential Equation, in particular this one:

$$dx_t = -\frac{1}{2}\beta(t)x_t\,dt + \sqrt{\beta(t)}\,d\omega_t$$
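To make this concrete, here is a minimal NumPy sketch of the discretized forward update above, assuming a hypothetical linear schedule $\beta(t)$ (the schedule, step count, and starting value are illustrative choices, not from the source). Repeated updates drive the samples toward a standard normal:

```python
import numpy as np

# Hypothetical linear noise schedule beta(t) on t in [0, 1] (an assumption, not from the source).
beta = lambda t: 0.1 + 19.9 * t   # beta_min = 0.1, beta_max = 20.0

T, n_steps = 1.0, 1000
dt = T / n_steps

rng = np.random.default_rng(0)
x = np.full(10_000, 2.0)          # start all samples at the data value x_0 = 2

for i in range(n_steps):
    t = i * dt
    # Discretized forward update: x_t = sqrt(1 - beta*dt) * x_{t-1} + sqrt(beta*dt) * noise
    x = np.sqrt(1.0 - beta(t) * dt) * x + np.sqrt(beta(t) * dt) * rng.standard_normal(x.shape)

print(x.mean(), x.std())          # approximately 0 and 1: x_T is close to N(0, I)
```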

Stochastic Differential Equation

  • Describes the diffusion in the infinitesimal (extremely small step) limit

Ordinary Differential Equation (ODE):

$$\frac{dx}{dt}=f(x,t) \quad\text{or}\quad dx= f(x,t)\,dt$$

  • $x$ is the state we are interested in (e.g., the pixels of an image)
  • $t$ is a continuous time variable along which the state $x$ changes/evolves
  • We can integrate the equation over time to obtain the final state $x$ as a function of $t$.
    • However, in practice $f$ is often a highly complex nonlinear function (e.g., a neural network)

img

The analytical solution (which generally cannot be found in closed form):

$$x(t) = x(0) + \int^t_0 f(x, \tau)\,d\tau$$

Iterative Numerical Solution

$$x(t + \Delta t) \approx x(t) + f(x(t), t)\,\Delta t$$
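For illustration, a minimal sketch of this Euler update for a toy ODE; the drift $f(x,t) = -x$ is an arbitrary example chosen only because its analytical solution is known:

```python
import numpy as np

def f(x, t):
    # Toy drift: exponential decay dx/dt = -x (illustrative choice only).
    return -x

x, t, dt = 1.0, 0.0, 0.01
while t < 5.0:
    x = x + f(x, t) * dt   # Euler step: x(t + dt) ≈ x(t) + f(x(t), t) * dt
    t += dt

print(x, np.exp(-5.0))     # numerical solution vs. the analytical solution e^{-t}
```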

Stochastic Differential Equation (SDE):

$$\frac{dx}{dt}= f(x,t) + \sigma(x,t)\,\omega_t$$

  • $\omega_t$ is called the Wiener process (in practice, Gaussian white noise)
  • $f(x,t)$ is the drift coefficient (which pulls the state towards a mode)
  • $\sigma(x,t)$ is the diffusion coefficient, which scales the injected noise

img

To solve it, we proceed similarly to the iterative numerical solution above:

$$x(t + \Delta t) \approx x(t) + f(x(t), t)\,\Delta t + \sigma(x(t), t)\, \sqrt{\Delta t}\,\mathcal{N}(0,I)$$

But because noise proportional to the diffusion coefficient is added through the Wiener process at every step, there is no unique solution as in the ODE case; every run produces a different trajectory.
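A minimal sketch of this stochastic update for the same toy drift plus a constant diffusion coefficient (both illustrative assumptions); running it with two different seeds gives two different trajectories:

```python
import numpy as np

def f(x, t):      # toy drift (illustrative)
    return -x

def sigma(x, t):  # toy constant diffusion coefficient (illustrative)
    return 0.5

def simulate(seed, x0=1.0, dt=0.01, T=5.0):
    rng = np.random.default_rng(seed)
    x, t = x0, 0.0
    while t < T:
        # Euler-Maruyama step: drift * dt plus noise scaled by sqrt(dt)
        x = x + f(x, t) * dt + sigma(x, t) * np.sqrt(dt) * rng.standard_normal()
        t += dt
    return x

print(simulate(seed=0), simulate(seed=1))  # different endpoints: no unique solution
```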

img

Forward Diffusion Process as SDE

As mentioned, the forward diffusion process can be written as an SDE:

$$dx_t = -\frac{1}{2}\beta(t)x_t\,dt + \sqrt{\beta(t)}\,d\omega_t$$

  • $-\frac{1}{2}\beta(t)x_t$ is the drift term (which pulls towards the mode)
  • $\sqrt{\beta(t)}$ is the diffusion coefficient (which injects noise)

img

This is a special case of the more general SDEs used in generative diffusion models:

$$dx_t = f(t)x_t\,dt + g(t)\,d\omega_t$$

Generative Reverse SDE

$$dx_t = -\frac{1}{2}\beta(t)\left[x_t + 2\nabla_{x_t} \log q_t (x_t)\right] dt + \sqrt{\beta(t)}\,d\bar{\omega}_t$$

where

  • $-\frac{1}{2}\beta(t)\left[x_t + 2\nabla_{x_t} \log q_t (x_t)\right]$ is the drift term
  • $\sqrt{\beta(t)}\,d\bar{\omega}_t$ is the diffusion term
  • $\nabla_{x_t} \log q_t (x_t)$ is the score function

But how do we get the score function $\nabla_{x_t} \log q_t (x_t)$?

  • Learn a neural network

Score Matching

$$\min _{\boldsymbol{\theta}} \mathbb{E}_{t \sim \mathcal{U}(0, T)} \mathbb{E}_{\mathbf{x}_t \sim q_t\left(\mathbf{x}_t\right)}\left\|\mathbf{s}_{\boldsymbol{\theta}}\left(\mathbf{x}_t, t\right)-\nabla_{\mathbf{x}_t} \log q_t\left(\mathbf{x}_t\right)\right\|_2^2$$

where

  • $\mathbb{E}_{t \sim \mathcal{U}(0, T)}$ is the expectation over the diffusion time $t$
  • $\mathbb{E}_{\mathbf{x}_t \sim q_t\left(\mathbf{x}_t\right)}$ is the expectation over the diffused data $\mathbf{x}_t$
  • $\mathbf{s}_{\boldsymbol{\theta}}\left(\mathbf{x}_t, t\right)$ is the neural network
  • $\nabla_{\mathbf{x}_t} \log q_t\left(\mathbf{x}_t\right)$ is the score of the diffused data (marginal)

But $\nabla_{\mathbf{x}_t} \log q_t\left(\mathbf{x}_t\right)$ (the score of the marginal diffused density $q_t(x_t)$) is not tractable.

  • Instead, diffuse individual data points $x_0$. The diffused conditional $q_t(x_t|x_0)$ is tractable.

Denoising Score Matching

“Variance Preserving” SDE:

$$dx_t = -\frac{1}{2}\beta(t)x_t\,dt + \sqrt{\beta(t)}\,d\omega_t$$

$$q_t(x_t|x_0) = \mathcal{N}(x_t;\gamma_t x_0, \sigma^2_t I)$$

$$\gamma_t = e^{-\frac{1}{2}\int^t_0\beta(s)\,ds}$$

$$\sigma^2_t = 1 - e^{-\int^t_0 \beta(s)\,ds}$$
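These closed-form coefficients make it cheap to sample $x_t$ directly given $x_0$. A minimal sketch, again assuming a hypothetical linear $\beta(t)$ schedule so that $\int_0^t \beta(s)\,ds$ has a closed form:

```python
import numpy as np

beta_min, beta_max = 0.1, 20.0           # hypothetical linear schedule (assumption)

def int_beta(t):
    # Closed-form integral of beta(s) = beta_min + (beta_max - beta_min) * s over [0, t]
    return beta_min * t + 0.5 * (beta_max - beta_min) * t ** 2

def gamma(t):
    return np.exp(-0.5 * int_beta(t))

def sigma2(t):
    return 1.0 - np.exp(-int_beta(t))

def sample_xt(x0, t, rng):
    # q_t(x_t | x_0) = N(gamma_t * x_0, sigma_t^2 I): one reparametrized sample
    eps = rng.standard_normal(x0.shape)
    return gamma(t) * x0 + np.sqrt(sigma2(t)) * eps, eps

rng = np.random.default_rng(0)
x0 = rng.standard_normal(4)              # stand-in for a data point
xt, eps = sample_xt(x0, t=0.5, rng=rng)
```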

$$\min _{\boldsymbol{\theta}} \mathbb{E}_{t \sim \mathcal{U}(0, T)} \mathbb{E}_{\mathbf{x}_0 \sim q_0\left(\mathbf{x}_0\right)} \mathbb{E}_{\mathbf{x}_t \sim q_t\left(\mathbf{x}_t \mid \mathbf{x}_0\right)}\left\| \mathbf{s}_{\boldsymbol{\theta}}\left(\mathbf{x}_t, t\right)-\nabla_{\mathbf{x}_t} \log q_t\left(\mathbf{x}_t \mid \mathbf{x}_0\right) \right\|_2^2$$

where

  • $\mathbb{E}_{t \sim \mathcal{U}(0, T)}$ is the expectation over the diffusion time $t$
  • $\mathbb{E}_{\mathbf{x}_0 \sim q_0\left(\mathbf{x}_0\right)}$ is the expectation over the data sample $x_0$
  • $\mathbb{E}_{\mathbf{x}_t \sim q_t\left(\mathbf{x}_t|\mathbf{x}_0\right)}$ is the expectation over the diffused data $x_t$
  • $\mathbf{s}_{\boldsymbol{\theta}}\left(\mathbf{x}_t, t\right)$ is the neural network
  • $\nabla_{\mathbf{x}_t} \log q_t\left(\mathbf{x}_t|\mathbf{x}_0\right)$ is the score of the diffused data sample

=> After taking the expectations, the learned network approximates the marginal score:

$$\mathbf{s}_{\boldsymbol{\theta}}\left(\mathbf{x}_t, t\right) \approx \nabla_{\mathbf{x}_t} \log q_t(\mathbf{x}_t)$$

Implementation 1: Noise Prediction

From

$$\min _{\boldsymbol{\theta}} \mathbb{E}_{t \sim \mathcal{U}(0, T)} \mathbb{E}_{\mathbf{x}_0 \sim q_0\left(\mathbf{x}_0\right)} \mathbb{E}_{\mathbf{x}_t \sim q_t\left(\mathbf{x}_t \mid \mathbf{x}_0\right)}\left\| \mathbf{s}_{\boldsymbol{\theta}}\left(\mathbf{x}_t, t\right)-\nabla_{\mathbf{x}_t} \log q_t\left(\mathbf{x}_t \mid \mathbf{x}_0\right) \right\|_2^2$$

Reparametrized sampling: $x_t = \gamma_t x_0 + \sigma_t\epsilon$, $\epsilon \sim \mathcal{N}(0,I)$

Score function: $\nabla_{\mathbf{x}_t} \log q_t\left(\mathbf{x}_t|\mathbf{x}_0\right) = - \nabla_{\mathbf{x}_t} \frac{(x_t - \gamma_t x_0)^2}{2\sigma^2_t} = -\frac{x_t - \gamma_t x_0}{\sigma^2_t}$

Neural network model: $\mathbf{s}_{\boldsymbol{\theta}}\left(\mathbf{x}_t, t\right) := -\frac{\epsilon_\theta(x_t, t)}{\sigma_t}$

=> Substituting $x_t = \gamma_t x_0 + \sigma_t\epsilon$ gives $\nabla_{\mathbf{x}_t} \log q_t(\mathbf{x}_t|\mathbf{x}_0) = -\frac{\epsilon}{\sigma_t}$, so the model is trained to predict the noise values $\epsilon$:

$$\min _{\boldsymbol{\theta}} \mathbb{E}_{t \sim \mathcal{U}(0, T)} \mathbb{E}_{\mathbf{x}_0 \sim q_0\left(\mathbf{x}_0\right)} \mathbb{E}_{\boldsymbol{\epsilon} \sim \mathcal{N}(\mathbf{0}, \mathbf{I})} \frac{1}{\sigma_t^2}\left\|\boldsymbol{\epsilon}-\boldsymbol{\epsilon}_{\boldsymbol{\theta}}\left(\mathbf{x}_t, t\right)\right\|_2^2$$

If our network can predict the noise values that were used for the perturbation, then we can denoise and reconstruct the original data point $x_0$ from $x_t$.
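A minimal PyTorch sketch of one training step with this noise-prediction objective. The tiny MLP `EpsNet`, the linear $\beta(t)$ schedule, and the toy data batch are all illustrative assumptions; a real model would use a U-Net or transformer:

```python
import torch
import torch.nn as nn

# Hypothetical linear noise schedule beta(t) = beta_min + (beta_max - beta_min) * t (an assumption).
beta_min, beta_max = 0.1, 20.0

def int_beta(t):
    # Closed-form integral of the linear beta(s) over [0, t]
    return beta_min * t + 0.5 * (beta_max - beta_min) * t ** 2

class EpsNet(nn.Module):
    """Toy stand-in for epsilon_theta(x_t, t)."""
    def __init__(self, dim=2):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(dim + 1, 128), nn.SiLU(), nn.Linear(128, dim))
    def forward(self, x, t):
        return self.net(torch.cat([x, t], dim=-1))

model = EpsNet()
opt = torch.optim.Adam(model.parameters(), lr=1e-3)

x0 = torch.randn(256, 2)                       # stand-in for a batch of data points
t = torch.rand(256, 1).clamp(min=1e-5)         # t ~ U(0, T); tiny clamp for numerical safety
gamma_t = torch.exp(-0.5 * int_beta(t))
sigma_t = torch.sqrt(1.0 - torch.exp(-int_beta(t)))

eps = torch.randn_like(x0)
x_t = gamma_t * x0 + sigma_t * eps             # reparametrized sample from q_t(x_t | x_0)

# Noise-prediction objective with the 1/sigma_t^2 weighting from above
# (loss weightings and time cut-offs are discussed in the following sections).
loss = ((eps - model(x_t, t)) ** 2 / sigma_t ** 2).mean()
opt.zero_grad()
loss.backward()
opt.step()
```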

Implementation 2: Loss Weightings

Denoising Score Matching objective with loss weighting $\lambda(t)$:

$$\min _{\boldsymbol{\theta}} \mathbb{E}_{t \sim \mathcal{U}(0, T)} \mathbb{E}_{\mathbf{x}_0 \sim q_0\left(\mathbf{x}_0\right)} \mathbb{E}_{\boldsymbol{\epsilon} \sim \mathcal{N}(\mathbf{0}, \mathbf{I})} \frac{\lambda(t)}{\sigma_t^2}\left\|\boldsymbol{\epsilon}-\boldsymbol{\epsilon}_{\boldsymbol{\theta}}\left(\mathbf{x}_t, t\right)\right\|_2^2$$

Different loss weightings trade off between models with good perceptual quality and models with high log-likelihood:

  • Perceptual quality: $\lambda(t) = \sigma^2_t$
  • Maximum log-likelihood: $\lambda(t) = \beta(t)$ (negative ELBO)

These are the same objectives as derived with the variational approach.

Implementation 3: Variance Reduction and Numerical Stability

$$\min _{\boldsymbol{\theta}} \mathbb{E}_{t \sim \mathcal{U}(0, T)} \mathbb{E}_{\mathbf{x}_0 \sim q_0\left(\mathbf{x}_0\right)} \mathbb{E}_{\boldsymbol{\epsilon} \sim \mathcal{N}(\mathbf{0}, \mathbf{I})} \frac{\lambda(t)}{\sigma_t^2}\left\|\boldsymbol{\epsilon}-\boldsymbol{\epsilon}_{\boldsymbol{\theta}}\left(\mathbf{x}_t, t\right)\right\|_2^2$$

Notice that $\sigma^2_t \rightarrow 0$ as $t \rightarrow 0$, so the loss is heavily amplified when sampling $t$ close to $0$ (for $\lambda(t)=\beta(t)$), leading to high variance. Because of this we would not want to use Implementation 2 right away; we add some tricks instead.

Trick 1: Train with a small time cut-off $\eta$ ($\approx 10^{-5}$):

$$\min _{\boldsymbol{\theta}} \mathbb{E}_{t \sim \mathcal{U}(\eta, T)} \mathbb{E}_{\mathbf{x}_0 \sim q_0\left(\mathbf{x}_0\right)} \mathbb{E}_{\boldsymbol{\epsilon} \sim \mathcal{N}(\mathbf{0}, \mathbf{I})} \frac{\lambda(t)}{\sigma_t^2}\left\|\boldsymbol{\epsilon}-\boldsymbol{\epsilon}_{\boldsymbol{\theta}}\left(\mathbf{x}_t, t\right)\right\|_2^2$$

Trick 2: Variance reduction by importance sampling:

  • Importance sampling distribution: $r(t) \propto \frac{\lambda(t)}{\sigma^2_t}$

$$\min _{\boldsymbol{\theta}} \mathbb{E}_{t \sim r(t)} \mathbb{E}_{\mathbf{x}_0 \sim q_0\left(\mathbf{x}_0\right)} \mathbb{E}_{\boldsymbol{\epsilon} \sim \mathcal{N}(\mathbf{0}, \mathbf{I})} \frac{1}{r(t)}\frac{\lambda(t)}{\sigma_t^2}\left\|\boldsymbol{\epsilon}-\boldsymbol{\epsilon}_{\boldsymbol{\theta}}\left(\mathbf{x}_t, t\right)\right\|_2^2$$
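One way such an importance-sampling distribution over $t$ could be implemented is to tabulate $r(t) \propto \lambda(t)/\sigma_t^2$ on a grid and sample via the inverse CDF. A minimal sketch, assuming $\lambda(t) = \beta(t)$ and the hypothetical linear schedule used earlier:

```python
import numpy as np

beta_min, beta_max, eta, T = 0.1, 20.0, 1e-5, 1.0        # hypothetical schedule and cut-off

def beta(t):
    return beta_min + (beta_max - beta_min) * t

def sigma2(t):
    ib = beta_min * t + 0.5 * (beta_max - beta_min) * t ** 2
    return 1.0 - np.exp(-ib)

# Unnormalized density r(t) proportional to lambda(t) / sigma_t^2 with lambda(t) = beta(t), on a grid.
grid = np.linspace(eta, T, 10_000)
dens = beta(grid) / sigma2(grid)
cdf = np.cumsum(dens)
cdf /= cdf[-1]

def sample_t(rng, n):
    # Inverse-CDF sampling: map uniform u to the grid point where the CDF first reaches u.
    u = rng.random(n)
    return grid[np.searchsorted(cdf, u)]

rng = np.random.default_rng(0)
t = sample_t(rng, 5)   # these t's are then reweighted by 1/r(t) in the objective
```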

img

Probability Flow ODE

Consider the reverse generative diffusion SDE: $dx_t = -\frac{1}{2}\beta(t)\left[x_t + 2\nabla_{x_t} \log q_t (x_t)\right] dt + \sqrt{\beta(t)}\,d\bar{\omega}_t$

In distribution, this is equivalent to the Probability Flow ODE:

$$dx_t = -\frac{1}{2}\beta(t)\left[x_t + \nabla_{\mathbf{x}_t} \log q_t(x_t)\right]dt$$

img

Probability Flow ODE: Diffusion Models as Continuous Normalizing Flows

So why should we care about and use this probability flow ODE framework? It turns out that an ordinary differential equation allows the use of advanced ODE solvers, which makes it easier to work with than an SDE.

  • Enables use of advanced ODE solvers
  • Deterministic encoding and generation (semantic image interpolation, etc.)
    • Allows encoding a data point into the latent space
      • Continuous changes in the latent space $x_T$ result in continuous, semantically meaningful changes in the data space $x_0$
  • Log-likelihood computation (instantaneous change of variables); see the sketch after this list:
    • $\log p_{\boldsymbol{\theta}}\left(\mathbf{x}_0\right)=\log p_T\left(\mathbf{x}_T\right)-\int_0^T \operatorname{Tr}\left(\frac{1}{2} \beta(t) \frac{\partial}{\partial \mathbf{x}_t}\left[\mathbf{x}_t+\mathbf{s}_{\boldsymbol{\theta}}\left(\mathbf{x}_t, t\right)\right]\right) \mathrm{d} t$
  • Diffusion models can be considered CNFs trained with score matching
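A rough sketch of how this log-likelihood could be estimated in practice, using a forward Euler discretization of the probability flow ODE and a Hutchinson estimator for the trace term. The `score_model`, the linear $\beta(t)$ schedule, and the step count are hypothetical placeholders:

```python
import torch

def beta(t):
    # Hypothetical linear schedule (an assumption, not from the source)
    return 0.1 + 19.9 * t

def log_likelihood(score_model, x0, n_steps=500):
    """Estimate log p_theta(x_0) by integrating the probability flow ODE from t=0 to T=1
    with forward Euler, using a Hutchinson estimator for the trace/divergence term."""
    x = x0.clone()
    delta_logp = torch.zeros(x.shape[0])
    dt = 1.0 / n_steps
    for i in range(n_steps):
        t = torch.full((x.shape[0], 1), i * dt)
        v = torch.randn_like(x)                          # Hutchinson probe vector
        with torch.enable_grad():
            x_in = x.detach().requires_grad_(True)
            drift = -0.5 * beta(t) * (x_in + score_model(x_in, t))   # PF-ODE drift
            # v^T (d drift / dx) v estimates Tr(d drift / dx) in expectation over v
            div = (torch.autograd.grad(drift, x_in, grad_outputs=v)[0] * v).sum(dim=-1)
        x = x + drift.detach() * dt                      # Euler step x_0 -> x_T
        delta_logp = delta_logp + div.detach() * dt      # accumulate the divergence integral
    d = x.shape[-1]
    log_prior = -0.5 * (x ** 2).sum(dim=-1) - 0.5 * d * torch.log(torch.tensor(2 * torch.pi))
    return log_prior + delta_logp                        # log p_T(x_T) + integral of divergence
```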

Synthesis with SDE vs ODE

img

In the SDE case the trajectories zigzag while following the distribution, while the ODE yields more deterministic trajectories. (Both land in the modes of the data distribution.)

  • Generative Reverse Diffusion SDE (stochastic):
    • $dx_t = -\frac{1}{2}\beta(t)[x_t + 2s_\theta(x_t,t)]dt + \sqrt{\beta(t)}\,d\bar{\omega}_t$
  • Generative Probability Flow ODE (deterministic):
    • $dx_t = -\frac{1}{2}\beta(t)[x_t + s_\theta(x_t, t)]dt$

Solving generative SDE or ODE in practice

Sampling from “Continuous-Time” Diffusion Models: How to solve the generative SDE or ODE in practice?

Generative Reverse Diffusion SDE (stochastic):

  • $dx_t = -\frac{1}{2}\beta(t)[x_t + 2s_\theta(x_t,t)]dt + \sqrt{\beta(t)}\,d\bar{\omega}_t$

Most naive way: Euler-Maruyama:

$$x_{t-1} = x_t + \frac{1}{2}\beta(t)[x_t + 2s_\theta(x_t, t)]\,\Delta t+\sqrt{\beta(t)\Delta t}\, \mathcal{N}(0, I)$$
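A minimal sketch of this Euler-Maruyama sampler, assuming a trained (hypothetical) `score_model` and the same illustrative linear $\beta(t)$ schedule as before; the loop integrates from $t = T$ down to $t \approx 0$:

```python
import torch

def beta(t):
    # Hypothetical linear schedule (an assumption, not from the source)
    return 0.1 + 19.9 * t

@torch.no_grad()
def sample_sde(score_model, shape, n_steps=1000, T=1.0):
    x = torch.randn(shape)                                   # x_T ~ N(0, I)
    dt = T / n_steps
    for i in reversed(range(n_steps)):
        t = torch.full((shape[0], 1), (i + 1) * dt)
        drift = 0.5 * beta(t) * (x + 2.0 * score_model(x, t))
        x = x + drift * dt + torch.sqrt(beta(t) * dt) * torch.randn_like(x)   # Euler-Maruyama step
    return x
```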

Generative Probability Flow ODE (deterministic):

  • $dx_t = -\frac{1}{2}\beta(t)[x_t + s_\theta(x_t, t)]dt$

The naive way uses Euler's method:

$$x_{t-1} = x_t + \frac{1}{2}\beta(t)[x_t + s_\theta(x_t, t)]\,\Delta t$$
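The corresponding deterministic Euler sketch for the probability flow ODE, with the same hypothetical `score_model` and schedule; the only differences are the missing noise term and the factor of 2 on the score:

```python
import torch

def beta(t):
    # Hypothetical linear schedule (an assumption, not from the source)
    return 0.1 + 19.9 * t

@torch.no_grad()
def sample_ode(score_model, shape, n_steps=100, T=1.0):
    x = torch.randn(shape)                                    # x_T ~ N(0, I)
    dt = T / n_steps
    for i in reversed(range(n_steps)):
        t = torch.full((shape[0], 1), (i + 1) * dt)
        x = x + 0.5 * beta(t) * (x + score_model(x, t)) * dt  # deterministic Euler step
    return x
```

With a good score model the ODE typically tolerates far fewer steps than the SDE, which is why fast samplers build on it.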

=> In practice: higher-order ODE solvers (Runge-Kutta, linear multistep methods, exponential integrators, ...)

What should we use to solve them?

Reconsider the generative diffusion SDE: $dx_t = -\frac{1}{2}\beta(t)[x_t + 2s_\theta(x_t,t)]dt + \sqrt{\beta(t)}\,d\bar{\omega}_t$

We can actually decompose it into two terms:

$$dx_t = -\frac{1}{2}\beta(t)[x_t + s_\theta(x_t,t)]\,dt - \frac{1}{2}\beta(t)s_\theta(x_t, t)\,dt + \sqrt{\beta(t)}\,d\bar{\omega}_t$$

where

  • $-\frac{1}{2}\beta(t)[x_t + s_\theta(x_t,t)]\,dt$ is the Probability Flow ODE part
  • $- \frac{1}{2}\beta(t)s_\theta(x_t, t)\,dt + \sqrt{\beta(t)}\,d\bar{\omega}_t$ is the Langevin Diffusion SDE part (combining the two is sketched below)
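One way to exploit this decomposition, in the spirit of predictor-corrector samplers (not spelled out in this section, so treat this as an assumption), is to alternate a deterministic probability-flow step with a stochastic Langevin correction step driven by the same learned score. A minimal sketch, reusing the hypothetical `score_model` interface from the samplers above:

```python
import torch

def beta(t):
    # Hypothetical linear schedule (an assumption, not from the source)
    return 0.1 + 19.9 * t

@torch.no_grad()
def sample_predictor_corrector(score_model, shape, n_steps=500, T=1.0, langevin_eps=1e-3):
    x = torch.randn(shape)                                    # x_T ~ N(0, I)
    dt = T / n_steps
    for i in reversed(range(n_steps)):
        t = torch.full((shape[0], 1), (i + 1) * dt)
        # Deterministic part: one probability flow ODE step
        x = x + 0.5 * beta(t) * (x + score_model(x, t)) * dt
        # Stochastic part: one Langevin step pushing samples towards q_t with the same score
        x = x + langevin_eps * score_model(x, t) + (2 * langevin_eps) ** 0.5 * torch.randn_like(x)
    return x
```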

SDE vs ODE Sampling: Pros and Cons

SDE Sampling:

  • Pro: Continuous noise injection can help to compensate for errors made during the diffusion process (Langevin sampling actively pushes samples towards the correct distribution).
  • Con: Often slower, because the stochastic terms themselves require a fine discretization during the solve.

ODE Sampling:

  • Pro: Can leverage fast ODE solvers. Best when targeting very fast sampling.
  • Con: No “stochastic” error correction, often slightly lower performance than stochastic sampling.

Diffusion Models as Energy-based Models

  • Assume an Energy-based Model (EBM): $p_{\boldsymbol{\theta}}(\mathbf{x}, t)=\frac{e^{-E_{\boldsymbol{\theta}}(\mathbf{x}, t)}}{\mathcal{Z}_{\boldsymbol{\theta}}(t)}$
  • Sample from the EBM via Langevin dynamics (sketched below): $x_{i+1} = x_i - \eta\nabla_{\mathbf{x}}E_\theta(\mathbf{x}_i, t) + \sqrt{2\eta}\,\mathcal{N}(0,I)$
  • Requires only the gradient of the energy $\nabla_{\mathbf{x}}E_\theta(\mathbf{x}_i, t)$, not $E_\theta(x,t)$ itself, nor the partition function $\mathcal{Z}_{\boldsymbol{\theta}}(t)$
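A minimal sketch of this Langevin sampler for a toy quadratic energy (a standard Gaussian EBM); the energy, step size, and step count are illustrative choices, not from the source:

```python
import torch

def energy(x, t):
    # Toy quadratic energy E(x, t) = ||x||^2 / 2, i.e. a standard Gaussian EBM (illustrative).
    return 0.5 * (x ** 2).sum(dim=-1)

def langevin_sample(x0, t, n_steps=1000, eta=1e-2):
    x = x0.clone()
    for _ in range(n_steps):
        x = x.detach().requires_grad_(True)
        grad = torch.autograd.grad(energy(x, t).sum(), x)[0]     # gradient of the energy
        # Langevin update: gradient step on the energy plus scaled Gaussian noise
        x = x.detach() - eta * grad + (2 * eta) ** 0.5 * torch.randn_like(x)
    return x

samples = langevin_sample(torch.randn(1000, 2) * 3.0, t=0.0)
print(samples.mean(dim=0), samples.std(dim=0))   # approaches N(0, I), the EBM's distribution
```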

In diffusion models, we learn “energy gradients” for all diffused distributions directly:

$$\nabla_{\mathbf{x}}\log q_t(x) \approx s_\theta(x,t) =: \nabla_{\mathbf{x}}\log p_\theta(x,t) = -\nabla_{\mathbf{x}}E_\theta(x,t) - \nabla_{\mathbf{x}}\log \mathcal{Z}_{\boldsymbol{\theta}}(t) = -\nabla_{\mathbf{x}}E_\theta(x,t)$$

where $\nabla_{\mathbf{x}}\log \mathcal{Z}_{\boldsymbol{\theta}}(t) = 0$, since the partition function does not depend on $\mathbf{x}$.

=> Diffusion models model the energy gradient directly, along the entire diffusion process, and avoid modeling the partition function. The different noise levels along the diffusion are analogous to annealed sampling in EBMs.

Unique Identifiability of Diffusion Models

The model is supposed to approximate the score function of the diffused data distribution $q_t(x_t)$.

This denoising model is in principle uniquely determined by the data that we’re given and the forward diffusion process.

  • The denoising model $s_\theta(x_t, t)$ and the deterministic data encodings are uniquely determined by the data and the fixed forward diffusion
  • Even with different architectures and initializations, we recover identical model outputs and encodings (given sufficient training data, model capacity and optimization accuracy), in contrast to GANs, VAEs, etc.

img

Summary

Why use Differential Equation Framework?

Advantages of the Differential Equation framework for Diffusion models

  • Can leverage the broad existing literature on advanced and fast SDE and ODE solvers when sampling from the model; this accelerates sampling from diffusion models, which is crucial because sampling can be slow
  • Allows us to construct deterministic Probability Flow ODE
    • Deterministic Data Encodings
    • Log-likelihood Estimation like Continuous Normalizing Flows, etc.
  • Clean mathematical framework based on Diffusion Processes and Score Matching; connections to Neural ODEs, Continuous Normalizing Flows and Energy-based Models