Paper Review - PDAE, DisDiff and InfoDiffusion
PDAE
Unsupervised Representation Learning from Pre-trained Diffusion Probabilistic Models
NeurIPS 2022
Code: https://github.com/ckczzj/PDAE
Diffusion AutoEncoders
- Unlike LDM, which focuses on generative modeling and uses an autoencoder for compression, PDAE centers on representation learning.
- PDAE trains an encoder and a “gradient estimator” decoder on top of a pre-trained diffusion model, leveraging the power of the pre-trained DPM for representation learning.
PDAE’s training strategy centers on minimizing the “posterior mean gap” of the pre-trained diffusion model. The encoder and gradient estimator are trained specifically to fill this gap, which leads to a more semantically meaningful latent space.
Posterior Mean Gap
- The forward process of diffusion gradually destroys information about the input.
- The reverse process predicts the posterior mean of the previous latent variable via the predicted noise.
- The pre-trained DPM might not accurately predict the true posterior mean; the discrepancy is the posterior mean gap (a numeric sketch follows this list).
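To make the gap concrete, here is a minimal sketch, assuming a standard DDPM linear β-schedule and a hypothetical `eps_model` standing in for the pre-trained noise predictor:

```python
import torch

# Standard DDPM linear beta schedule (illustrative values).
T = 1000
betas = torch.linspace(1e-4, 0.02, T)
alphas = 1.0 - betas
alphas_bar = torch.cumprod(alphas, dim=0)

def true_posterior_mean(x0, xt, t):
    # Mean of q(x_{t-1} | x_t, x_0); assumes t >= 1.
    ab_t, ab_prev = alphas_bar[t], alphas_bar[t - 1]
    coef_x0 = torch.sqrt(ab_prev) * betas[t] / (1 - ab_t)
    coef_xt = torch.sqrt(alphas[t]) * (1 - ab_prev) / (1 - ab_t)
    return coef_x0 * x0 + coef_xt * xt

def predicted_mean(eps_model, xt, t):
    # Mean implied by the DPM's noise prediction eps_theta(x_t, t).
    eps = eps_model(xt, t)
    return (xt - betas[t] / torch.sqrt(1 - alphas_bar[t]) * eps) / torch.sqrt(alphas[t])

# Posterior mean gap at step t: the residual PDAE's mean shift learns to fill.
# gap = true_posterior_mean(x0, xt, t) - predicted_mean(eps_model, xt, t)
```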
The goal of the paper is more effective and efficient representation learning.
Method
Section 3.2 of the paper
PDAE minimizes the discrepancy between the predicted and true posterior means (“Unsupervised Representation Learning by Filling the Gap”). This is done by:
- Training an encoder $E_\varphi$ that produces the representation $z$
  - The encoder learns to extract meaningful representations from the input image
  - Simply stacked convolutional layers followed by a linear layer are enough to learn meaningful representations (found to work better than a U-Net encoder); see the sketch after this list
- Training a gradient estimator $G_\psi$
  - The gradient estimator, taking the encoded representation as input, predicts a “mean shift” that corrects the pre-trained DPM’s predictions during the reverse process
  - It essentially predicts how the latent representation influences the generation at each step of the reverse diffusion process
  - Architecturally, it reuses the middle blocks, decoder part, and output blocks of the U-Net
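A minimal sketch of the kind of encoder described above; the channel widths, activation, and latent size are assumptions for illustration, not the paper’s configuration:

```python
import torch.nn as nn

class SimpleEncoder(nn.Module):
    """Stacked conv layers plus a linear head producing the representation z."""

    def __init__(self, in_ch=3, z_dim=512):
        super().__init__()
        self.convs = nn.Sequential(
            nn.Conv2d(in_ch, 64, 3, stride=2, padding=1), nn.SiLU(),
            nn.Conv2d(64, 128, 3, stride=2, padding=1), nn.SiLU(),
            nn.Conv2d(128, 256, 3, stride=2, padding=1), nn.SiLU(),
            nn.Conv2d(256, 512, 3, stride=2, padding=1), nn.SiLU(),
        )
        self.head = nn.Linear(512, z_dim)

    def forward(self, x0):
        h = self.convs(x0).mean(dim=(2, 3))  # global average pooling
        return self.head(h)                  # z = E_phi(x_0)
```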
- Gray part: the pre-trained diffusion model $\epsilon_\theta$ (frozen)
- Green part: gradient estimator $G_\psi$ that predicts the mean shift for correction
  - $z$ is incorporated by applying Group Normalization with scaling and shifting (Equation 9 in the paper); see the sketch after this list
- Blue part: encoder $E_\varphi$ that produces $z$
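A sketch of how $z$ could enter through Group Normalization with scaling and shifting (AdaGN-style conditioning; the exact layer shapes here are assumptions):

```python
import torch.nn as nn

class AdaGroupNorm(nn.Module):
    """GroupNorm whose scale and shift are predicted from the representation z."""

    def __init__(self, num_channels, z_dim, num_groups=32):
        super().__init__()
        self.norm = nn.GroupNorm(num_groups, num_channels, affine=False)
        self.to_scale_shift = nn.Linear(z_dim, 2 * num_channels)

    def forward(self, h, z):
        scale, shift = self.to_scale_shift(z).chunk(2, dim=1)
        # Broadcast (B, C) -> (B, C, 1, 1) over the spatial dimensions.
        scale = scale[:, :, None, None]
        shift = shift[:, :, None, None]
        return self.norm(h) * (1 + scale) + shift
```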
During the training of the encoder and gradient estimator, the system is treated like a regular conditional diffusion model (note that $\epsilon_\theta$ is frozen):

$$\mathbb{E}_{x_0, \epsilon, t}\left[ \lambda_t \left\| \epsilon - \epsilon_\theta(x_t, t) + \frac{\sqrt{\alpha_t}\sqrt{1 - \bar{\alpha}_t}}{\beta_t}\, \Sigma_\theta(x_t, t)\, G_\psi(x_t, z, t) \right\|^2 \right]$$

- $\epsilon$: noise added during the forward diffusion process
- $\epsilon_\theta(x_t, t)$: the noise predicted by the pre-trained DPM
- $\alpha_t$, $\bar{\alpha}_t$, $\beta_t$: terms of the variance schedule of the diffusion process
- $\Sigma_\theta(x_t, t)$: the variance of the pre-trained DPM
- $G_\psi(x_t, z, t)$: the gradient estimator, taking the noisy image ($x_t$), encoded representation ($z$), and timestep ($t$) as input

A training-step sketch follows the symbol list.
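A hedged sketch of one training step under the loss above. `sched` is an assumed container exposing the schedule terms (`alphas`, `betas`, `alphas_bar`, and the DPM variance as `sigma2`), and a scalar timestep is used for clarity:

```python
import torch

def pdae_training_loss(eps_model, grad_estimator, encoder, x0, t, sched, lam):
    """One PDAE-style training step for a single scalar timestep t.

    eps_model is the frozen pre-trained DPM; only encoder and
    grad_estimator receive gradients.
    """
    eps = torch.randn_like(x0)
    ab_t = sched.alphas_bar[t]
    xt = torch.sqrt(ab_t) * x0 + torch.sqrt(1 - ab_t) * eps  # forward process

    z = encoder(x0)
    with torch.no_grad():                   # pre-trained DPM stays frozen
        eps_pred = eps_model(xt, t)

    # Map the mean shift Sigma_theta * G_psi into noise space; the factor
    # follows from the relation between the posterior mean and the noise.
    scale = torch.sqrt(sched.alphas[t]) * torch.sqrt(1 - ab_t) / sched.betas[t]
    shift = scale * sched.sigma2[t] * grad_estimator(xt, z, t)

    return (lam[t] * (eps - eps_pred + shift) ** 2).mean()
```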
The optimization is equivalent to minimizing the difference between the (scaled) predicted mean shift and $\epsilon - \epsilon_\theta(x_t, t)$; that is, $G_\psi$ learns to account for exactly the part of the noise that the pre-trained diffusion model fails to predict.
$\lambda_t$ was originally set to 1, but training then became extremely unstable and failed to converge. Investigating, the authors found that the mean shift during the critical stage of the diffusion trajectory contains more of the information needed to reconstruct the input (measured by recovering the class label from samples) than the other two stages (Figure 3).
Therefore they design a weighting scheme in terms of the signal-to-noise ratio $\mathrm{SNR}(t) = \bar{\alpha}_t / (1 - \bar{\alpha}_t)$:

$$\lambda_t = \left(\frac{1}{1 + \mathrm{SNR}(t)}\right)^{1-\gamma} \cdot \left(\frac{\mathrm{SNR}(t)}{1 + \mathrm{SNR}(t)}\right)^{\gamma}$$

where the first factor handles the early stage and the second the late stage; $\gamma$ balances the strength of down-weighting between the two. This weighting scheme down-weights the diffusion loss at both low and high SNR (a small sketch follows).
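A small sketch of this weighting, assuming the SNR definition and two-factor form above (the default `gamma` here is illustrative, not the paper’s value):

```python
import torch

def snr_weight(alphas_bar, t, gamma=0.5):
    """Loss weight that down-weights both very high SNR (low-noise steps)
    and very low SNR (high-noise steps); gamma trades off the two factors."""
    snr = alphas_bar[t] / (1 - alphas_bar[t])
    return (1 / (1 + snr)) ** (1 - gamma) * (snr / (1 + snr)) ** gamma
```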
Insights
Averaged over all steps, the posterior mean gap shrinks markedly: PDAE predicts a mean shift that significantly fills the gap.
- Better training efficiency compared to prior work (Diff-AE)
  - Owing to reusing the U-Net encoder part of the pre-trained DPM, PDAE has fewer trainable parameters and achieves higher training throughput than Diff-AE
- Modeling the posterior mean gap based on pre-trained DPMs is easier than modeling a conditional DPM from scratch
- PDAE leverages pre-trained DPMs for better representation learning.
DisDiff
DisDiff: Unsupervised Disentanglement of Diffusion Probabilistic Models
NeurIPS 2023
Code: https://github.com/thomasmry/DisDiff
The code is a bit messy and is built on top of the LDM codebase.
- DisDiff disentangles a DPM into several disentangled sub-gradient fields, which improves the interpretability of the DPM.
- DisDiff is an unsupervised framework which learns a disentangled representation and a disentangled gradient field for each factor.
- Constraints on the representations are enforced through the diffusion model’s classifier guidance and the score-based conditioning trick.
Method
DisDiff is based on the Latent Diffusion Model.
Given a pre-trained DDPM model, the target is to disentangle the DPM in an unsupervised manner.
Given an input image $x_0$ from the dataset, for each of the $N$ underlying factors $c \in \{1, \dots, N\}$, the goal is to learn the disentangled representation $z^c$ and its corresponding disentangled gradient field $\nabla_{x_t} \log p(z^c \mid x_t)$.
An encoder $E_\phi$ with learnable parameters $\phi$ is employed to obtain all representations $\{z^1, \dots, z^N\} = E_\phi(x_0)$.
A decoder $G_\psi$ is used to estimate the gradient fields, $G_\psi(x_t, z^c, t) \approx \nabla_{x_t} \log p(z^c \mid x_t)$.

The training loss of this disentangled diffusion model is the PDAE (Pre-trained DPM AutoEncoding) loss with the single gradient field replaced by the summation of the estimated per-factor gradient fields:

$$\mathbb{E}_{x_0, \epsilon, t}\left[ \lambda_t \left\| \epsilon - \epsilon_\theta(x_t, t) + \frac{\sqrt{\alpha_t}\sqrt{1 - \bar{\alpha}_t}}{\beta_t}\, \Sigma_\theta(x_t, t) \sum_{c=1}^{N} G_\psi(x_t, z^c, t) \right\|^2 \right]$$

(A sketch follows.)
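A sketch of the corresponding training step; as in the PDAE sketch earlier, `sched` and the callables are assumed stand-ins, with a scalar timestep for clarity:

```python
import torch

def disdiff_diffusion_loss(eps_model, decoder, encoder, x0, t, sched, lam):
    """PDAE-style loss with the single mean shift replaced by the sum of
    per-factor gradient fields (scalar timestep t for clarity)."""
    eps = torch.randn_like(x0)
    ab_t = sched.alphas_bar[t]
    xt = torch.sqrt(ab_t) * x0 + torch.sqrt(1 - ab_t) * eps

    zs = encoder(x0)                  # sequence of factor representations z^c
    with torch.no_grad():
        eps_pred = eps_model(xt, t)   # frozen pre-trained DPM

    scale = torch.sqrt(sched.alphas[t]) * torch.sqrt(1 - ab_t) / sched.betas[t]
    shift = scale * sched.sigma2[t] * sum(decoder(xt, z, t) for z in zs)

    return (lam[t] * (eps - eps_pred + shift) ** 2).mean()
```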
(a) The network architecture of DisDiff
- Grey networks: pre-trained U-Net of the DPM (frozen)
- The image $x_0$ is encoded into per-factor representations $\{z^c\}$ by the encoder $E_\phi$
- Each $z^c$ is then decoded by the decoder $G_\psi$ to obtain its gradient field
(b) Demonstration of the disentangling loss
- First sample a factor $c$ and decode it into its gradient field, which yields the predicted $\hat{x}_0^c$ for that factor.
- At the same time, the prediction $\hat{x}_0$ of the original (unconditioned) pre-trained DPM is obtained.
- Encode both images into two sets of representations and calculate the disentangling loss on them.
Disentangling Loss
The disentangling loss is optimized w.r.t. the decoder $G_\psi$ but not the encoder $E_\phi$.
Disentangling Loss consists of two components: Invariant Loss and Variant Loss.
Both are cross-entropy losses.
Cross-entropy loss penalizes incorrect classifications heavily, especially when the predicted probability is far from the actual class.
The loss minimizes an upper-bound estimator of the mutual information between the representation of the sampled factor and the representations of the other factors, using the following quantities:
- $x_0$ is the input image
- $\hat{x}_0$ is the reconstructed image (unconditioned)
- $\hat{x}_0^c$ is the reconstructed image conditioned on $z^c$
- $z^k = E_\phi^k(x_0)$ and $\hat{z}^k = E_\phi^k(\hat{x}_0^c)$
- $\bar{z}^k = E_\phi^k(\hat{x}_0)$

(see the sketch after this list for how these are produced)
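A sketch of how the three representation sets could be produced. `sample_or_approx` is a hypothetical helper standing in for the score-based conditioning trick described below, and `encoder` is assumed to return an indexable sequence of per-factor representations:

```python
import torch

def disentangling_inputs(encoder, sample_or_approx, x0, xt, num_factors):
    """Produce the representation sets compared by the disentangling loss."""
    c = torch.randint(0, num_factors, ()).item()  # sample one factor index

    z = encoder(x0)                               # z^k from the input image
    x0_hat_c = sample_or_approx(xt, cond=z[c])    # conditioned on z^c
    x0_hat = sample_or_approx(xt, cond=None)      # unconditioned

    z_hat = encoder(x0_hat_c)                     # \hat{z}^k
    z_bar = encoder(x0_hat)                       # \bar{z}^k
    return c, z, z_hat, z_bar
```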
Invariant Loss
- Encourages invariance of the representations of the factors not being sampled: conditioning on the sampled factor $c$ should not affect the representations of the other factors during generation.
- Mainly promotes the disentanglement.

Implementation of the invariant loss:
- Compute a distance $d_k$ between $\hat{z}^k$ and $\bar{z}^k$ for every factor index $k$.
- Apply a cross-entropy loss to identify the index $c$ from these distances, which pushes the distances at all other indexes down (a sketch follows).
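A sketch of the invariant loss under these definitions; the squared-Euclidean distance is an assumption of this sketch, not necessarily the paper’s choice:

```python
import torch
import torch.nn.functional as F

def invariant_loss(z_hat, z_bar, c):
    """Cross-entropy over per-factor distances between the conditioned and
    unconditioned reconstructions' representations, stacked as (N, d).
    Identifying index c pushes the distances at all other indexes down."""
    d = ((z_hat - z_bar) ** 2).sum(dim=-1)        # one distance per factor
    return F.cross_entropy(d.unsqueeze(0), torch.tensor([c]))
```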
Variant Loss
- Encourages the representation of the factor being conditioned on to be close to the corresponding representation of the input data.

Implementation of the variant loss:
- Compute a distance $d_k$ between $\hat{z}^k$ and $z^k$ for every factor index $k$.
- Apply a cross-entropy loss that maximizes the distances at indexes $k \neq c$ but minimizes the distance at index $c$ (a sketch follows).
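A matching sketch of the variant loss; again, the distance choice and sign convention are assumptions of this sketch:

```python
import torch
import torch.nn.functional as F

def variant_loss(z_hat, z, c):
    """Cross-entropy over negative distances between the conditioned
    reconstruction's representations and the input's, stacked as (N, d).
    The distance at index c is driven down relative to the others."""
    d = ((z_hat - z) ** 2).sum(dim=-1)            # one distance per factor
    return F.cross_entropy(-d.unsqueeze(0), torch.tensor([c]))
```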
Using the score-based conditioning trick, $\hat{x}_0$ and $\hat{x}_0^c$ can be approximated efficiently in a single denoising step instead of running the full reverse process.
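A sketch of the standard one-step approximation this trick relies on:

```python
import torch

def approx_x0(xt, eps_hat, alphas_bar, t):
    """One-step estimate of the denoised image from x_t, avoiding a full
    reverse pass; eps_hat is the (shift-corrected) noise prediction at t."""
    ab_t = alphas_bar[t]
    return (xt - torch.sqrt(1 - ab_t) * eps_hat) / torch.sqrt(ab_t)
```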
Total Loss

The total loss combines the diffusion (reconstruction) loss above with the disentangling loss, with a weighting hyperparameter on the disentangling term.
Insights
- DisDiff achieves much better TAD and FID metrics on CelebA compared to β-VAE and InfoVAE, and slightly better results compared to Diff-AE and InfoDiffusion.
- The disentanglement in DisDiff derives from the decomposition of the gradient field of the diffusion model, so the space in which diffusion operates influences performance. It is hard for the model to learn shape and scale in image space, but much easier in the latent space of the autoencoder; therefore the LDM version of DisDiff outperforms the image-space version.
CLIP-guided DisDiff
By leveraging knowledge learned from pre-trained CLIP, DisDiff can learn meaningful gradient fields.
InfoDiffusion
Paper: InfoDiffusion: Representation Learning Using Information Maximizing Diffusion Models
ICML 2023 (PMLR)
Code: No Code