PDAE

Unsupervised Representation Learning from Pre-trained Diffusion Probabilistic Models

NeurIPS 2022

Code: https://github.com/ckczzj/PDAE

Diffusion AutoEncoders

  • Unlike LDM, which focuses on generative modeling and uses an autoencoder for compression, PDAE centers on representation learning.
  • PDAE trains an encoder and a “gradient estimator” decoder around a frozen pre-trained diffusion model, leveraging the power of a pre-trained DPM for representation learning.

PDAE’s training strategy centers on minimizing the “posterior mean gap” of the pre-trained diffusion model. The encoder and gradient estimator are trained specifically to fill this gap, which leads to a more semantically meaningful latent space.

Posterior Mean Gap

  • The forward diffusion process destroys information about the input.
  • The reverse process of a pre-trained DPM predicts the mean of the denoising distribution from the noisy latent x_t and t alone
    • so it cannot accurately match the true posterior mean; the difference is the posterior mean gap (made precise below).
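Concretely, using the notation from the loss below, the gap at step t is the difference between the true posterior mean and the mean predicted by the pre-trained DPM, \widetilde{\boldsymbol{\mu}}_t\left(\boldsymbol{x}_t, \boldsymbol{x}_0\right)-\boldsymbol{\mu}_\theta\left(\boldsymbol{x}_t, t\right). Since \mu_\theta only sees x_t and t, the information about x_0 lost in the forward process cannot be recovered exactly, which is what creates the gap.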

The goal of the paper is more effective and efficient representation learning.

Method

Section 3.2 of the paper

PDAE minimizes the discrepancy between the predicted and true posterior means, i.e., “Unsupervised Representation Learning by Filling the Gap”, by:

  • Training an encoder that produces the representation z = E_\varphi(x_0)
    • The encoder learns to extract a meaningful representation from the input image
    • Simply stacked convolution layers followed by a linear layer are enough to learn a meaningful z from x_0 (found to work better than a U-Net encoder; a minimal sketch follows this list)
  • Training a gradient estimator G_\psi(x_t, E_\varphi(x_0), t)
    • The gradient estimator, taking the encoded representation as input, predicts a “mean shift” that corrects the pre-trained DPM’s predictions during the reverse process
      • it essentially predicts how the latent representation should influence the generation at each step of the reverse diffusion process
    • Implemented with the middle blocks, decoder part, and output blocks of a U-Net
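A minimal sketch of such an encoder (my own illustrative PyTorch code, not the authors' exact architecture; channel widths, depth, and the latent dimension are assumptions):

```python
import torch
import torch.nn as nn

class SimpleEncoder(nn.Module):
    """Stacked conv layers + a linear head mapping x_0 to a latent z.

    Illustrative only: depths/widths are assumptions, not the paper's config.
    """
    def __init__(self, in_channels=3, base=64, z_dim=512, image_size=64):
        super().__init__()
        layers, ch = [], in_channels
        # Each block halves the spatial resolution.
        for mult in (1, 2, 4, 8):
            layers += [nn.Conv2d(ch, base * mult, 4, stride=2, padding=1), nn.SiLU()]
            ch = base * mult
        self.conv = nn.Sequential(*layers)
        feat = image_size // 16  # spatial size after four stride-2 convs
        self.fc = nn.Linear(ch * feat * feat, z_dim)

    def forward(self, x0):
        h = self.conv(x0)
        return self.fc(h.flatten(1))  # z = E_phi(x_0)
```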

img

  • Gray part: the pre-trained diffusion model (frozen)
  • Green part: the gradient estimator that predicts the mean shift for correction
    • z is incorporated by applying Group Normalization with scaling and shifting (Equation 9 in the paper)
  • Blue part: the encoder that produces z

The encoder and gradient estimator are trained like a regular conditional diffusion model (note that \theta is frozen):

\mathcal{L}(\psi, \varphi)=\mathbb{E}_{\boldsymbol{x}_0, t, \epsilon}\left[\lambda_t\left\|\epsilon-\boldsymbol{\epsilon}_\theta\left(\boldsymbol{x}_t, t\right)+\frac{\sqrt{\alpha_t} \sqrt{1-\bar{\alpha}_t}}{\beta_t} \cdot \boldsymbol{\Sigma}_\theta\left(\boldsymbol{x}_t, t\right) \cdot \boldsymbol{G}_\psi\left(\boldsymbol{x}_t, \boldsymbol{E}_{\varphi}\left(\boldsymbol{x}_0\right), t\right)\right\|^2\right]

  • \epsilon : the noise added during the forward diffusion process
  • \boldsymbol{\epsilon}_\theta\left(\boldsymbol{x}_t, t\right) : the noise predicted by the pre-trained DPM
  • \alpha_t, \bar{\alpha}_t, \beta_t : terms of the variance (noise) schedule of the diffusion process
  • \boldsymbol{\Sigma}_\theta\left(\boldsymbol{x}_t, t\right) : the variance of the pre-trained DPM
  • \boldsymbol{G}_\psi\left(\boldsymbol{x}_t, \boldsymbol{E}_{\varphi}\left(\boldsymbol{x}_0\right), t\right) : the gradient estimator, taking the noisy image x_t, the encoded representation z, and the timestep t as input

The optimization is equivalent to minimizing \left\|\boldsymbol{\Sigma}_\theta\left(\boldsymbol{x}_t, t\right) \cdot \boldsymbol{G}_\psi\left(\boldsymbol{x}_t, \boldsymbol{E}_{\varphi}\left(\boldsymbol{x}_0\right), t\right)-\left(\widetilde{\boldsymbol{\mu}}_t\left(\boldsymbol{x}_t, \boldsymbol{x}_0\right)-\boldsymbol{\mu}_\theta\left(\boldsymbol{x}_t, t\right)\right)\right\|^2, i.e., the predicted mean shift is trained to match the posterior mean gap between the true posterior mean \widetilde{\boldsymbol{\mu}}_t and the mean \boldsymbol{\mu}_\theta predicted by the pre-trained diffusion model.
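A hedged sketch of one training step for this objective (illustrative PyTorch code, not the repository's implementation; `pretrained_eps` and `pretrained_var` stand for the frozen DPM's \epsilon_\theta and \Sigma_\theta, `lambdas` is the per-timestep weight discussed below, and the schedule tensors are assumed precomputed):

```python
import torch

def pdae_loss(x0, t, noise, encoder, grad_estimator,
              pretrained_eps, pretrained_var,
              alphas, alphas_bar, betas, lambdas):
    """One step of the PDAE objective above, written for clarity.

    pretrained_eps / pretrained_var come from the frozen DPM (theta);
    encoder / grad_estimator are the trainable phi / psi modules.
    """
    a_bar = alphas_bar[t].view(-1, 1, 1, 1)
    xt = a_bar.sqrt() * x0 + (1 - a_bar).sqrt() * noise     # forward diffusion

    with torch.no_grad():                                    # theta is frozen
        eps = pretrained_eps(xt, t)                          # epsilon_theta(x_t, t)
        var = pretrained_var(xt, t)                          # Sigma_theta(x_t, t)

    z = encoder(x0)                                          # z = E_phi(x_0)
    shift = grad_estimator(xt, z, t)                         # G_psi(x_t, z, t)

    coeff = (alphas[t].sqrt() * (1 - alphas_bar[t]).sqrt() / betas[t]).view(-1, 1, 1, 1)
    residual = noise - eps + coeff * var * shift
    return (lambdas[t].view(-1, 1, 1, 1) * residual.pow(2)).mean()
```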

\lambda_t was originally set to 1, but training then became extremely unstable and failed to converge. Investigating this, the authors found that the mean shift during the critical stage contains more information crucial for reconstructing the class label of input samples than the other two stages (Figure 3).

img

Therefore they design a weighting scheme in terms of the signal-to-noise ratio SNR(t)=\frac{\bar{\alpha}_t}{1-\bar{\alpha}_t}:

\lambda_t = \left(\frac{1}{1+SNR(t)}\right)^{1-\gamma} \cdot \left(\frac{SNR(t)}{1+SNR(t)}\right)^\gamma

where the first term targets the early stage and the second the late stage. \gamma = 0.1 balances the strength of down-weighting between the two terms, so this weighting scheme down-weights the diffusion loss at both low and high SNR.
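As a quick sanity check, the weighting can be computed directly from the cumulative schedule (a small illustrative snippet; `alphas_bar` is the precomputed \bar{\alpha}_t tensor):

```python
import torch

def loss_weight(alphas_bar, gamma=0.1):
    """lambda_t = (1/(1+SNR))^(1-gamma) * (SNR/(1+SNR))^gamma with SNR = a_bar/(1-a_bar).

    Note 1/(1+SNR) = 1 - a_bar and SNR/(1+SNR) = a_bar, so this simplifies to
    (1 - a_bar)^(1-gamma) * a_bar^gamma.
    """
    snr = alphas_bar / (1 - alphas_bar)
    return (1 / (1 + snr)) ** (1 - gamma) * (snr / (1 + snr)) ** gamma
```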

Insights

img

Average posterior mean gap for all steps. PDAE predicts the mean shift that significantly fills the posterior mean gap.

  • Better training efficiency compared to prior work (Diff-AE)
    • Owing to the reuse of the U-Net encoder part of the pre-trained DPM, PDAE has fewer trainable parameters and achieves a higher training throughput than Diff-AE
    • Modeling the posterior mean gap based on pre-trained DPMs is easier than modeling a conditional DPM from scratch
  • PDAE leverages pre-trained DPMs for better representation learning.

DisDiff

DisDiff: Unsupervised Disentanglement of Diffusion Probabilistic Models

NeurIPS 2023

Code: https://github.com/thomasmry/DisDiff

The code is a bit messy and is built upon the LDM codebase.

  • DisDiff disentangles a DPM into several sub-gradient fields, which can improve the interpretability of the DPM.
  • DisDiff is an unsupervised framework that learns a disentangled representation and a disentangled gradient field for each factor.
  • It enforces constraints on the representations through the diffusion model’s classifier guidance and the score-based conditioning trick.

Method

DisDiff is built on the Latent Diffusion Model (LDM).

Given a pre-trained DDPM model, the target is to disentangle the DPM in an unsupervised manner.

Given an input x_0 in the dataset D, for each factor c \in C (the set of factors), the goal is to learn the disentangled representation z^c and its corresponding disentangled gradient field \nabla_{x_t} \log p(z^c \mid x_t).

An encoder E_\phi with learnable parameters \phi is employed to obtain all the z^c representations:

E_\phi\left(x_0\right)=\left\{E_\phi^1\left(x_0\right), E_\phi^2\left(x_0\right), \ldots, E_\phi^N\left(x_0\right)\right\}=\left\{z^1, z^2, \ldots, z^N\right\}

A decoder G_\psi\left(x_t, z^c, t\right) is used to estimate the gradient fields \nabla_{x_t} \log p(z^c \mid x_t):

\nabla_{x_t} \log p\left(z^S \mid x_t\right)=\sum_{c \in S} \nabla_{x_t} \log p\left(z^c \mid x_t\right)

where z^S = \{z^c \mid c \in S\}.
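A minimal sketch of the factor-wise encoding and the summation over a subset S (illustrative PyTorch code; the shared backbone, per-factor heads, and the decoder signature are assumptions, not the repository's exact interfaces):

```python
import torch
import torch.nn as nn

class FactorEncoder(nn.Module):
    """E_phi: one shared backbone with N heads, producing {z^1, ..., z^N}."""
    def __init__(self, backbone, feat_dim, z_dim, num_factors):
        super().__init__()
        self.backbone = backbone
        self.heads = nn.ModuleList([nn.Linear(feat_dim, z_dim) for _ in range(num_factors)])

    def forward(self, x0):
        h = self.backbone(x0)
        return [head(h) for head in self.heads]            # [z^1, ..., z^N]

def summed_gradient_field(decoder, xt, zs, t, subset):
    """sum_{c in S} G_psi(x_t, z^c, t), i.e. the gradient field for z^S."""
    return sum(decoder(xt, zs[c], t) for c in subset)
```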

img

The training loss of this disentangled diffusion model is:

\mathcal{L}_r=\underset{x_0, t, \epsilon}{\mathbb{E}}\left\|\epsilon-\epsilon_\theta\left(x_t, t\right)+\frac{\sqrt{\alpha_t} \sqrt{1-\bar{\alpha}_t}}{\beta_t} \sigma_t \sum_{c \in \mathcal{C}} G_\psi\left(x_t, z^c, t\right)\right\|

This is similar to the PDAE (Pre-trained DPM AutoEncoding) loss, but the single gradient field is replaced with the summation of the estimated gradient fields over all factors.

img

(a) The network architecture of DisDiff

  • Grey networks: the pre-trained U-Net of the DPM
  • The image x_0 is encoded into representations \{z^1, z^2, \cdots, z^N\} of different factors by the encoder
  • which are then decoded by the decoder to obtain the gradient fields

(b) The demonstration of disentangling loss

  • First, sample a factor c and decode it into a gradient field, yielding the predicted \hat{x}^c_0 conditioned on that factor.
  • At the same time, the unconditional prediction \hat{x}_0 of the original pre-trained DPM is obtained.
  • Both images are encoded into representations, and the disentangling loss is computed from them.

Disentangling Loss

The disentangling loss is optimized w.r.t. the decoder G but not the encoder E.

Disentangling Loss consists of two components: Invariant Loss and Variant Loss.

Both are cross-entropy losses.

Cross-entropy loss penalizes incorrect classifications heavily, especially when the predicted probability is far from the actual class.

They minimize an upper-bound estimator of the mutual information between z^{c} and z^{k\neq c}:

\min \underset{k, c, x_0, \hat{x}_0^c}{\mathbb{E}}\left\|\hat{z}^{k \mid c}-\hat{z}^k\right\|+\left\|\hat{z}^{c \mid c}-z^c\right\|-\left\|\hat{z}^{k \mid c}-z^k\right\|

where

  • x_0 is the input image
  • \hat{x}_0 is the reconstructed image (unconditional)
  • \hat{x}^c_0 is the reconstructed image conditioned on z^c
  • \hat{z}^k = E^k_{\phi}(\hat{x}_0) and \hat{z}^{k \mid c} = E^k_\phi(\hat{x}^c_0)
  • \hat{z}^c = E^c_{\phi}(\hat{x}_0) and \hat{z}^{c \mid c} = E^c_\phi(\hat{x}^c_0)

Invariant Loss \mathcal{L}_{in}

  • Encourages the representations of factors that are not sampled to remain invariant, i.e., the sampled factor does not affect the representations of other factors during the generation process.
  • Mainly promotes disentanglement

Implementation of the Invariant Loss:

  • d_k = \left\|\hat{z}^{k \mid c}-\hat{z}^k\right\|, \quad d = [d_1, d_2, \ldots, d_N]
  • A cross-entropy loss to identify the index c, which minimizes the distances at the other indices.

\mathcal{L}_{in}=\underset{c, \hat{x}_0, \hat{x}_0^c}{\mathbb{E}}[\text{CrossEntropy}(d, c)]
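A minimal sketch of this loss (illustrative PyTorch code; the representations are assumed to be lists of per-factor (B, z_dim) tensors and c a (B,) tensor of sampled factor indices, which may differ from the repository):

```python
import torch
import torch.nn.functional as F

def invariant_loss(z_hat_given_c, z_hat, c):
    """L_in: d_k = ||z_hat^{k|c} - z_hat^k|| for every factor k, CE against the sampled index c."""
    d = torch.stack([(a - b).norm(dim=-1)
                     for a, b in zip(z_hat_given_c, z_hat)], dim=1)  # (B, N)
    return F.cross_entropy(d, c)
```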

Variant Loss \mathcal{L}_{va}

  • Encourages the representation of the factor being conditioned upon to be close to the corresponding representation of the input data.

Implementation of the Variant Loss:

  • d^n_k = \left\|\hat{z}^{k}-z^k\right\|, \quad d^n = [d^n_1, d^n_2, \ldots, d^n_N]
  • d^p_k = \left\|\hat{z}^{k \mid c}-z^k\right\|, \quad d^p = [d^p_1, d^p_2, \ldots, d^p_N]
  • A cross-entropy loss to maximize the distances at indices k \neq c but minimize the distance at index c.

\mathcal{L}_{va}=\underset{c, x_0, \hat{x}_0, \hat{x}_0^c}{\mathbb{E}}\left[\text{CrossEntropy}\left(d^n-d^p, c\right)\right]
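And correspondingly for the variant loss (same assumed conventions as the invariant-loss sketch):

```python
import torch
import torch.nn.functional as F

def variant_loss(z_hat, z_hat_given_c, z, c):
    """L_va: logits d^n - d^p, with d^n_k = ||z_hat^k - z^k|| and d^p_k = ||z_hat^{k|c} - z^k||."""
    d_n = torch.stack([(a - b).norm(dim=-1) for a, b in zip(z_hat, z)], dim=1)          # (B, N)
    d_p = torch.stack([(a - b).norm(dim=-1) for a, b in zip(z_hat_given_c, z)], dim=1)  # (B, N)
    return F.cross_entropy(d_n - d_p, c)
```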

Using the score-based conditioning trick, \hat{x}^c_0 can be approximated efficiently.
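For reference, a sketch of how this trick is usually written (my own illustration following the standard classifier-guidance formulation, so the sign convention is an assumption): the \epsilon-prediction is shifted by the estimated gradient field, and x_t is then inverted to read off \hat{x}^c_0.

```python
import torch

def predict_x0_conditioned(xt, t, eps_pretrained, grad_field, alphas_bar):
    """Score-based conditioning sketch: shift the epsilon prediction by the gradient
    field (classifier-guidance style), then invert
    x_t = sqrt(a_bar) * x_0 + sqrt(1 - a_bar) * eps to obtain x_hat_0^c."""
    a_bar = alphas_bar[t].view(-1, 1, 1, 1)
    eps_cond = eps_pretrained - (1 - a_bar).sqrt() * grad_field  # assumed sign convention
    return (xt - (1 - a_bar).sqrt() * eps_cond) / a_bar.sqrt()
```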

Total Loss

\mathcal{L}_a=\mathcal{L}_r+\lambda\left\|\hat{x}_0-\hat{x}_0^c\right\|^2\left(\mathcal{L}_{in}+\mathcal{L}_{va}\right)
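Putting it together (a sketch; `lam` is the hyperparameter \lambda, and whether the gate \left\|\hat{x}_0-\hat{x}_0^c\right\|^2 is applied per sample or batch-averaged is an assumption):

```python
import torch

def total_loss(l_r, l_in, l_va, x0_hat, x0_hat_c, lam):
    """L_a = L_r + lambda * ||x_hat_0 - x_hat_0^c||^2 * (L_in + L_va)."""
    gate = (x0_hat - x0_hat_c).pow(2).flatten(1).sum(dim=1).mean()
    return l_r + lam * gate * (l_in + l_va)
```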

Insights

  • DisDiff achieves much better TAD and FID metrics on CelebA than \beta-VAE and InfoVAE, and slightly better results than Diff-AE and InfoDiffusion
  • The disentanglement in DisDiff comes from decomposing the gradient field of the diffusion model, so the diffusion space influences its performance. It is hard for the model to learn shape and scale in image space, but much easier in the latent space of the autoencoder; hence the LDM version of DisDiff outperforms the image-space version.

CLIP-guided DisDiff

img

By leveraging knowledge learned from pre-trained CLIP, DisDiff can learn meaningful gradient fields.

InfoDiffusion

Paper: InfoDiffusion: Representation Learning Using Information Maximizing Diffusion Models

ICML 2023 (PMLR)

Code: not released

Method