Overview of Common Generative Adversarial Network Methods
GAN
Loss function
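Loss function of Discriminator
$$\max_D \; \mathbb{E}_{x \sim p_{data}(x)}[\log D(x)] + \mathbb{E}_{z \sim p_z(z)}[\log(1 - D(G(z)))]$$
Where
- $D$ is the discriminator
- $G$ is the generator
- $x$ is a real training sample
- $z$ is random noise
Since $x$ is a real training sample, we want $D(x)$ to output 1. Then $\log D(x) = \log 1 = 0$, its maximum value.
Since $z$ is random noise, $G(z)$ is a generated (fake) sample. The discriminator wants $D(G(z))$ to output 0, while the generator wants it to output 1.
From the discriminator's point of view, we want $D(G(z))$ to output 0 in the term $\log(1 - D(G(z)))$, so that $\log(1 - 0) = 0$ (not fooled by the generator). If it is fooled by the generator and $D(G(z))$ is close to 1, the term $\log(1 - D(G(z)))$ becomes a large negative number (it tends to $-\infty$).
The Discriminator wants to maximize this expression.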
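Loss function of Generator
$$\min_G \; \mathbb{E}_{x \sim p_{data}(x)}[\log D(x)] + \mathbb{E}_{z \sim p_z(z)}[\log(1 - D(G(z)))]$$
Where
- $D$ is the discriminator
- $G$ is the generator
- $x$ is a real training sample (the first term does not depend on $G$)
- $z$ is random noise
From the generator's point of view, it wants to fool the discriminator into believing that $D(G(z))$ is close to 1 (i.e. that the generated sample is real).
The Generator wants to minimize this expression.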
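In practice, the Generator is instead trained to maximize $\log D(G(z))$, i.e. to minimize:
$$\min_G \; -\mathbb{E}_{z \sim p_z(z)}[\log D(G(z))]$$
This new loss function gives non-saturating gradients (strong gradients early in training, when the discriminator easily rejects generated samples), as suggested by the original GAN paper.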
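To see why the non-saturating loss helps, it is enough to look at the gradient with respect to the discriminator's logit $s$ for a generated sample, where $D(G(z)) = \sigma(s)$. The sketch below (plain NumPy; the function names are illustrative, and the rest of the chain rule through $G$ is ignored) compares both generator losses when the discriminator confidently rejects fakes, i.e. $s \ll 0$:

```python
import numpy as np

def sigmoid(s):
    return 1.0 / (1.0 + np.exp(-s))

def saturating_grad(s):
    # d/ds of log(1 - sigmoid(s)): the original minimax generator loss (minimized by G)
    return -sigmoid(s)

def non_saturating_grad(s):
    # d/ds of -log(sigmoid(s)): the non-saturating generator loss (minimized by G)
    return -(1.0 - sigmoid(s))

# Early in training D rejects fakes with high confidence: sigmoid(s) is close to 0
for s in [-8.0, -4.0, 0.0]:
    print(f"logit={s:+.1f}  D(G(z))={sigmoid(s):.4f}  "
          f"saturating={saturating_grad(s):+.4f}  non_saturating={non_saturating_grad(s):+.4f}")
```

For $s = -8$ the saturating gradient is only about $-3 \times 10^{-4}$, while the non-saturating one stays close to $-1$.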
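Putting the loss functions together
$$\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{data}(x)}[\log D(x)] + \mathbb{E}_{z \sim p_z(z)}[\log(1 - D(G(z)))]$$
In implementation, both losses are computed as a binary cross-entropy (BCE) loss on the discriminator's logits: the discriminator uses label 1 for real samples and label 0 for generated samples, and the generator's non-saturating loss uses label 1 for generated samples.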
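A minimal NumPy sketch of both losses written as BCE on raw logits (an illustration, not a specific library's API; frameworks provide the same computation, e.g. PyTorch's `torch.nn.BCEWithLogitsLoss`). Using the form `max(s, 0) - s*y + log(1 + exp(-|s|))` avoids overflow and `log(0)`:

```python
import numpy as np

def bce_with_logits(logits, targets):
    # Stable form of -[y*log(sigmoid(s)) + (1 - y)*log(1 - sigmoid(s))]
    return np.maximum(logits, 0.0) - logits * targets + np.log1p(np.exp(-np.abs(logits)))

def discriminator_loss(real_logits, fake_logits):
    # Minimizing this maximizes log D(x) + log(1 - D(G(z))): real -> label 1, fake -> label 0
    loss_real = bce_with_logits(real_logits, np.ones_like(real_logits)).mean()
    loss_fake = bce_with_logits(fake_logits, np.zeros_like(fake_logits)).mean()
    return loss_real + loss_fake

def generator_loss(fake_logits):
    # Non-saturating generator loss -log D(G(z)): fake samples labeled as real (1)
    return bce_with_logits(fake_logits, np.ones_like(fake_logits)).mean()

real_logits = np.array([2.0, 3.5, 1.0])    # made-up discriminator logits
fake_logits = np.array([-2.0, -0.5, 0.3])
print(discriminator_loss(real_logits, fake_logits), generator_loss(fake_logits))
```

When the discriminator cannot tell the two apart (all logits 0), its loss is $2 \log 2$.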
Questions Raised
Are GANs sensitive to hyperparameters?
Yes
What is the best learning rate for Adam optimizer?
- Andrej Karpathy's well-known (half-joking) rule of thumb is 3e-4 (0.0003).
- But:
- the DCGAN paper uses 2e-4 (0.0002).
- There are always exceptions!
Can we use different learning rate for generator and discriminator?
Yes, but most papers use the same learning rate for the generator and the discriminator.
Why does JS divergence have gradient issues?
Note that, with an optimal discriminator, the GAN loss used so far is equivalent to the JS divergence between the real and generated distributions.
However, JS divergence is not suitable here, because the loss often gives no useful information.
JS divergence is always log 2 if the two distributions do not overlap, and in most cases they do not overlap. When they do not overlap, a binary classifier can achieve 100% accuracy, so the accuracy (or loss) means nothing during GAN training: it stays the same whether the generated distribution is close to or far from the real one, and the generator gets no useful gradient.
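The equivalence mentioned above comes from the analysis in the original GAN paper. For a fixed generator the optimal discriminator is $D^*(x) = \frac{p_{data}(x)}{p_{data}(x) + p_G(x)}$, and substituting it back into $V(D, G)$ gives:

$$\begin{aligned}
V(D^*, G) &= \mathbb{E}_{x \sim p_{data}}\left[\log \frac{p_{data}(x)}{p_{data}(x) + p_G(x)}\right] + \mathbb{E}_{x \sim p_G}\left[\log \frac{p_G(x)}{p_{data}(x) + p_G(x)}\right] \\
&= -\log 4 + D_{KL}\left(p_{data} \,\middle\|\, \tfrac{p_{data} + p_G}{2}\right) + D_{KL}\left(p_G \,\middle\|\, \tfrac{p_{data} + p_G}{2}\right) \\
&= -\log 4 + 2\,JS(p_{data} \,\|\, p_G)
\end{aligned}$$

If the two distributions do not overlap, $JS(p_{data} \,\|\, p_G) = \log 2$ and $V(D^*, G) = 0$ no matter how far apart they are, which is exactly why the loss carries no information in that case.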
DCGAN
Architecture guidelines for stable Deep Convolutional GANs
- Replace any pooling layers with strided convolutions (discriminator) and fractional-strided convolutions (generator).
- Use batch normalization in both the generator and discriminator
- Remove fully connected hidden layers for deeper architectures
- Use ReLU activation in generator for all layers except for the output, which uses Tanh.
- Use LeakyReLU activation in the discriminator for all layers.
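Replacing pooling with (fractional-)strided convolutions means the layers themselves change the spatial size. Assuming 4x4 kernels, stride 2 and padding 1 (a setting common in DCGAN re-implementations, used here as an assumption), each strided convolution halves the size and each fractional-strided (transposed) convolution doubles it. A small helper to check the sizes (function names are illustrative):

```python
def conv_out(n, kernel=4, stride=2, padding=1):
    # Output size of a strided convolution (discriminator downsampling)
    return (n + 2 * padding - kernel) // stride + 1

def conv_transpose_out(n, kernel=4, stride=2, padding=1):
    # Output size of a fractional-strided (transposed) convolution (generator upsampling)
    return (n - 1) * stride - 2 * padding + kernel

g_sizes = [4]                      # generator starts from a 4x4 feature map
for _ in range(4):
    g_sizes.append(conv_transpose_out(g_sizes[-1]))
d_sizes = [64]                     # discriminator starts from a 64x64 image
for _ in range(4):
    d_sizes.append(conv_out(d_sizes[-1]))
print(g_sizes, d_sizes)  # [4, 8, 16, 32, 64] [64, 32, 16, 8, 4]
```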
Experimental details from the DCGAN paper
- All models were trained with mini-batch stochastic gradient descent (SGD) with a mini-batch size of 128.
- All weights were initialized from a zero-centered Normal distribution with standard deviation 0.02.
- In the LeakyReLU, the slope of the leak was set to 0.2 in all models.
- Adam optimizer with learning rate 0.0002 and momentum term beta1 = 0.5 (instead of the suggested 0.9) to stabilize training
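The settings above collected in one place, with the weight initialization and LeakyReLU written out in NumPy. The config dict and the helper names `init_weights` and `leaky_relu` are hypothetical, chosen for illustration:

```python
import numpy as np

# Settings reported in the DCGAN paper, gathered in a hypothetical config dict
DCGAN_CONFIG = {
    "batch_size": 128,
    "init_std": 0.02,
    "leaky_relu_slope": 0.2,
    "adam_lr": 0.0002,
    "adam_beta1": 0.5,
}

def init_weights(shape, rng, std=DCGAN_CONFIG["init_std"]):
    # Zero-centered Normal initialization with standard deviation 0.02
    return rng.normal(loc=0.0, scale=std, size=shape)

def leaky_relu(x, slope=DCGAN_CONFIG["leaky_relu_slope"]):
    # LeakyReLU: x where x >= 0, slope * x elsewhere
    return np.where(x >= 0, x, slope * x)

rng = np.random.default_rng(0)
w = init_weights((64, 3, 4, 4), rng)   # e.g. a conv kernel: 64 filters, 3 input channels, 4x4
print(w.std(), leaky_relu(np.array([-1.0, 2.0])))
```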
Problems of DCGAN / GAN
- Training is unstable because the generator and discriminator must stay balanced (neither should overpower the other)
- Mode collapse may happen
- The GAN loss (JS divergence) is always log 2 if the two distributions do not overlap. In most cases the two distributions do not overlap, because both typically lie on low-dimensional manifolds in a high-dimensional space.
WGAN
Why WGAN (Wasserstein GAN)? Pros and Cons of WGAN:
- Pros:
- Better stability (less prone to mode collapse)
- Loss now means something: it correlates with sample quality, so it can serve as a termination criterion
- Cons:
- Might take longer to train?
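WGAN uses the Wasserstein distance:
$$\min_G \max_{D \in 1\text{-Lipschitz}} \; \mathbb{E}_{x \sim P_{data}}[D(x)] - \mathbb{E}_{x \sim P_G}[D(x)]$$
- Where $P_{data}$ is the real distribution and $P_G$ is the generated distribution
- The Discriminator wants to separate $P_{data}$ and $P_G$ as much as possible (high $D(x)$ on real samples, low $D(x)$ on generated samples)
- Maximize the whole expression
- The Generator wants to move $P_G$ and $P_{data}$ closer to each other
- Minimize the whole expression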
Quickly comparing JS divergence to Wasserstein Distance:
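A small numerical comparison in the spirit of the WGAN paper's example of two distributions that are shifted copies of each other. As an assumption for illustration, $P$ and $Q$ are point masses on a discrete 1-D grid, placed at position 0 and at position $\theta$ (function names are illustrative):

```python
import numpy as np

def kl(p, q):
    # KL(p || q) for discrete distributions; terms with p = 0 contribute 0
    mask = p > 0
    return float(np.sum(p[mask] * np.log(p[mask] / q[mask])))

def js(p, q):
    m = 0.5 * (p + q)
    return 0.5 * kl(p, m) + 0.5 * kl(q, m)

def wasserstein_1d(p, q, dx=1.0):
    # In 1-D, W1 equals the area between the two cumulative distribution functions
    return float(np.sum(np.abs(np.cumsum(p) - np.cumsum(q))) * dx)

grid_size = 10
for theta in [0, 1, 3, 6]:
    p = np.zeros(grid_size); p[0] = 1.0       # all mass at position 0
    q = np.zeros(grid_size); q[theta] = 1.0   # all mass at position theta
    print(theta, round(js(p, q), 4), wasserstein_1d(p, q))
```

For every $\theta > 0$ the JS divergence stays at $\log 2 \approx 0.6931$, so it says nothing about how far apart the distributions are, while the Wasserstein distance equals $\theta$ and shrinks smoothly as $Q$ moves toward $P$.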
Implementation of WGAN
1-Lipschitz function
- Force the parameters $w$ to lie between $-c$ and $c$ (weight clipping to a small range; the paper uses $c = 0.01$)
- Why? Because without a constraint, the training of the Discriminator will not converge ($D(x)$ could grow without bound on real samples).
- Uses RMSProp instead of Adam
- Does not need BCE loss (so the Sigmoid at the Discriminator's output is also removed)
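A minimal NumPy sketch of the WGAN losses and weight clipping (function names are illustrative; `params` stands in for the list of critic weight arrays). Scores are raw critic outputs with no sigmoid, and $c = 0.01$ follows the WGAN paper:

```python
import numpy as np

CLIP_VALUE = 0.01  # c in the WGAN paper

def critic_loss(real_scores, fake_scores):
    # The critic maximizes E[D(x)] - E[D(G(z))]; written here as a loss to minimize
    return -(np.mean(real_scores) - np.mean(fake_scores))

def generator_loss(fake_scores):
    # The generator maximizes E[D(G(z))], i.e. moves P_G toward P_data
    return -np.mean(fake_scores)

def clip_weights(params, c=CLIP_VALUE):
    # Weight clipping: force every critic parameter into [-c, c] after each update
    return [np.clip(w, -c, c) for w in params]

params = [np.array([0.5, -0.5, 0.001]), np.array([[0.02, -0.003]])]
print(clip_weights(params))
```

In the paper the critic is updated several times ($n_{critic} = 5$) per generator update, with clipping applied after every critic update.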
Improved WGAN (WGAN-GP)
Improved Training of Wasserstein GANs
Weight clipping is a clearly terrible way to enforce a Lipschitz constraint. Why?
If the clipping parameter is large, then it can take a long time for any weights to reach their limit, thereby making it harder to train the critic till optimality. If the clipping is small, this can easily lead to vanishing gradients when the number of layers is big, or batch normalization is not used.
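- Uses a Gradient Penalty that pushes the norm of the discriminator's gradient toward 1, evaluated at random points interpolated between real and generated samples
- Batch norm in the discriminator is replaced by instance norm or layer norm (the paper recommends layer norm)
- Adam optimizer is used (the paper uses learning rate 1e-4 and beta2 = 0.9)
- Momentum not used (beta1 = 0)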
def gradient_penalty(discriminator, real, fake, device="cpu"):
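WGAN-GP adds $\lambda\,\mathbb{E}_{\hat{x}}[(\lVert \nabla_{\hat{x}} D(\hat{x}) \rVert_2 - 1)^2]$ to the critic loss, where $\hat{x} = \epsilon x + (1 - \epsilon)\tilde{x}$ mixes a real sample $x$ and a generated sample $\tilde{x}$ with $\epsilon \sim U[0, 1]$, and the paper uses $\lambda = 10$. In a deep learning framework the gradient comes from autograd (e.g. `torch.autograd.grad` with `create_graph=True`, so the penalty itself can be backpropagated). To stay self-contained, the sketch below swaps autograd for central finite differences in NumPy, so it illustrates the computation but is not usable for training; `lam` and `rng` are illustrative parameters that replace the `device` argument above:

```python
import numpy as np

def numerical_grad(f, x, eps=1e-5):
    # Central finite-difference gradient of a scalar function f at vector x
    grad = np.zeros_like(x, dtype=float)
    for i in range(x.size):
        step = np.zeros_like(x, dtype=float)
        step[i] = eps
        grad[i] = (f(x + step) - f(x - step)) / (2 * eps)
    return grad

def gradient_penalty(discriminator, real, fake, lam=10.0, rng=None):
    # real, fake: arrays of shape (batch, features); discriminator maps one vector to a scalar
    rng = np.random.default_rng() if rng is None else rng
    eps = rng.uniform(0.0, 1.0, size=(real.shape[0], 1))  # one mixing weight per sample
    x_hat = eps * real + (1.0 - eps) * fake                 # random points between real and fake
    norms = np.array([np.linalg.norm(numerical_grad(discriminator, x)) for x in x_hat])
    return lam * np.mean((norms - 1.0) ** 2)                # two-sided penalty toward norm 1

rng = np.random.default_rng(0)
w = np.array([2.0, 0.0])              # toy linear critic whose gradient norm is 2 everywhere
critic = lambda x: float(w @ x)
real, fake = rng.normal(size=(8, 2)), rng.normal(size=(8, 2))
print(gradient_penalty(critic, real, fake, rng=rng))  # approximately 10 * (2 - 1)^2 = 10
```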
LSGAN
Paper: Least Squares Generative Adversarial Networks
Key feature:
- proposed Least Squares Generative Adversarial Networks (LSGANs) which adopt the least squares loss function (MSE) for the discriminator.
- able to generate higher quality images than regular GANs
- improved stability of learning process
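Loss function of LSGAN
$$\min_D V_{LSGAN}(D) = \frac{1}{2}\mathbb{E}_{x \sim p_{data}(x)}[(D(x) - b)^2] + \frac{1}{2}\mathbb{E}_{z \sim p_z(z)}[(D(G(z)) - a)^2]$$
$$\min_G V_{LSGAN}(G) = \frac{1}{2}\mathbb{E}_{z \sim p_z(z)}[(D(G(z)) - c)^2]$$
where:
- $a$ is the label for fake samples
- $b$ is the label for real samples
- $c$ denotes the value that the Generator wants the Discriminator to believe for a fake sample.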
Benefits of LSGAN (Why?)
Note that regular GANs use the sigmoid cross-entropy loss function for the discriminator.
The LSGAN paper pointed out that regular GANs with sigmoid cross entropy loss function will lead to the problem of vanishing gradients when updating the generator using the fake samples that are on the correct side of the decision boundary, but are still far from the real data.
- regular GANs produce almost no loss for generated samples that lie far away on the correct side of the decision boundary
LSGANs penalize those samples even though they are correctly classified
- the penalty pushes the generator to generate samples toward the decision boundary
- moving the generated samples toward the decision boundary brings them closer to the manifold of real data
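The vanishing-gradient argument can be checked on the discriminator output $s$ for a generated sample that already lies far on the real side of the boundary (large $s$). The sketch below (illustrative names; target $c = 1$ and a raw, sigmoid-free output assumed for the least-squares loss) compares the gradient of the regular generator loss $-\log \sigma(s)$ with that of $\frac{1}{2}(s - c)^2$:

```python
import numpy as np

def sigmoid_ce_grad(s):
    # d/ds of -log(sigmoid(s)): the regular (sigmoid cross-entropy) generator loss
    return -(1.0 - 1.0 / (1.0 + np.exp(-s)))

def least_squares_grad(s, c=1.0):
    # d/ds of 0.5 * (s - c)^2: the LSGAN generator loss on the raw discriminator output
    return s - c

for s in [1.0, 5.0, 10.0]:
    print(s, sigmoid_ce_grad(s), least_squares_grad(s))
```

At $s = 10$ the sigmoid cross-entropy gradient is on the order of $10^{-5}$, while the least-squares gradient is 9, so the sample is still pulled back toward $c$.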
Implementation of LSGAN
- Remove the sigmoid function from the discriminator's output
- Use MSE to replace the BCE loss of regular GANs
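A minimal NumPy sketch of the two LSGAN losses, assuming the 0-1 coding $a = 0$, $b = 1$, $c = 1$ (one of the label choices discussed in the paper; $a = -1$, $b = 1$, $c = 0$ is another). Function names are illustrative:

```python
import numpy as np

A_FAKE, B_REAL, C_GEN = 0.0, 1.0, 1.0  # labels a, b, c

def lsgan_discriminator_loss(real_out, fake_out, a=A_FAKE, b=B_REAL):
    # Raw discriminator outputs (no sigmoid); MSE toward the real label b and the fake label a
    return 0.5 * np.mean((real_out - b) ** 2) + 0.5 * np.mean((fake_out - a) ** 2)

def lsgan_generator_loss(fake_out, c=C_GEN):
    # The generator wants the discriminator to output c for generated samples
    return 0.5 * np.mean((fake_out - c) ** 2)

real_out = np.array([0.9, 1.2])   # made-up raw discriminator outputs
fake_out = np.array([0.1, -0.3])
print(lsgan_discriminator_loss(real_out, fake_out), lsgan_generator_loss(fake_out))
```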
SN-GAN
Paper: Spectral Normalization for Generative Adversarial Networks
Key features of SNGAN:
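- Spectral normalization of the discriminator: each weight matrix is divided by its spectral norm (largest singular value), which constrains the discriminator's Lipschitz constant
- Stabilizes the training of the discriminator
- Generated examples are more diverse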
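Spectral normalization divides each weight matrix $W$ by $\sigma(W)$, estimated cheaply by power iteration ($v \leftarrow W^\top u / \lVert W^\top u \rVert$, $u \leftarrow W v / \lVert W v \rVert$, $\sigma \approx u^\top W v$). A NumPy sketch (names are illustrative; a convolution kernel would first be reshaped into a 2-D matrix of shape (out_channels, in_channels * k * k)):

```python
import numpy as np

def spectral_norm(W, n_iters=1, u=None, rng=None):
    """Return W / sigma(W), with sigma estimated by power iteration, and the updated u."""
    rng = np.random.default_rng() if rng is None else rng
    if u is None:
        u = rng.normal(size=W.shape[0])
        u /= np.linalg.norm(u)
    for _ in range(n_iters):
        v = W.T @ u
        v /= np.linalg.norm(v)
        u = W @ v
        u /= np.linalg.norm(u)
    sigma = u @ W @ v  # estimate of the largest singular value
    return W / sigma, u

rng = np.random.default_rng(0)
W = rng.normal(size=(4, 3))
W_sn, u = spectral_norm(W, n_iters=30, rng=rng)
print(np.linalg.norm(W_sn, 2))  # close to 1.0
W_sn, u = spectral_norm(W, n_iters=1, u=u)  # next training step: reuse u, one iteration
```

In the paper, $u$ is kept between training steps, so a single power iteration per step is enough; PyTorch ships a ready-made version as `torch.nn.utils.spectral_norm`.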
Supplementary
Must know information
Do you train the Generator and Discriminator together at once? No:
- A typical GAN alternates between training the discriminator and training the generator.
- It does NOT train the generator and the discriminator simultaneously!
- While it’s possible for a GAN to use the same loss for both generator and discriminator training (or the same loss differing only in sign), it’s not required. In fact it’s more common to use different losses for the discriminator and the generator.
- During generator training, gradients propagate through the discriminator network to the generator network (although the discriminator does not update its weights during generator training). So the weights in the discriminator network influence the updates to the generator network.
- GANs are hard to train, because the Generator and Discriminator need to stay matched to each other
- At the end of (ideal) training, the Generator (the counterfeiter in the usual analogy) produces dollar bills indistinguishable from real ones, and the Discriminator is forced to guess (it outputs probability 0.5)
- Training of GANs benefits dramatically from large batch sizes.
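The alternating schedule can be sketched as follows. `d_step` and `g_step` are hypothetical callables standing in for one optimizer update of each network, and `d_steps_per_g_step` > 1 corresponds to settings such as WGAN's several critic updates per generator update:

```python
def train_gan(num_steps, d_step, g_step, d_steps_per_g_step=1):
    """Alternate discriminator and generator updates (never both in the same update)."""
    for _ in range(num_steps):
        # 1) Update D while G is fixed (generated samples are treated as constants)
        for _ in range(d_steps_per_g_step):
            d_step()
        # 2) Update G while D's weights are frozen; gradients still flow through D into G
        g_step()

calls = []
train_gan(3, lambda: calls.append("D"), lambda: calls.append("G"))
print(calls)  # ['D', 'G', 'D', 'G', 'D', 'G']
```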
Definition of Training Steps and Training Epochs
- Epoch: A training epoch is one complete pass over all the training data for gradient computation and optimization.
- Step: A training step uses one batch of training data to update the model once.
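The two definitions are linked by a simple count: the number of steps per epoch is the number of mini-batches needed to cover the dataset once. A sketch with hypothetical numbers (60,000 samples, batch size 128); `drop_last` mimics the common option of skipping the final incomplete batch:

```python
import math

def steps_per_epoch(num_samples, batch_size, drop_last=False):
    # One step = one mini-batch; one epoch = one pass over all training samples
    if drop_last:
        return num_samples // batch_size
    return math.ceil(num_samples / batch_size)

def total_steps(num_samples, batch_size, epochs, drop_last=False):
    return epochs * steps_per_epoch(num_samples, batch_size, drop_last)

print(steps_per_epoch(60000, 128))        # 469 (the last batch holds only 96 samples)
print(total_steps(60000, 128, epochs=25)) # 11725
```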
Why do we need a Z input in a generative model?
When the task needs "creativity" (the same input should produce different outputs), the output must be a distribution rather than a single value. The Generator takes Z (usually sampled from a normal distribution) and transforms it into a complex distribution.
Divergence in Generator
- Used to measure the distance between 2 distributions (Whether they are similar)
- Lower Divergence = more similar
- Although we do not know the distributions, we can sample from them to estimate the divergence.
Objective Function in Discriminator
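- The discriminator wants the objective function to be as large as possible
- Its maximum value is related to the divergence between the real and generated distributions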
But how to define the distance between two probability distributions?
The most common ways:
- KL divergence (Kullback-Leibler divergence)
- JS divergence (Jensen-Shannon divergence) <= GAN
- Wasserstein Distance <= WGAN
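For reference, the standard definitions of the three distances listed above, for distributions $P$ and $Q$:

$$D_{KL}(P \,\|\, Q) = \mathbb{E}_{x \sim P}\left[\log \frac{P(x)}{Q(x)}\right]$$
$$JS(P \,\|\, Q) = \tfrac{1}{2} D_{KL}(P \,\|\, M) + \tfrac{1}{2} D_{KL}(Q \,\|\, M), \quad M = \tfrac{1}{2}(P + Q)$$
$$W(P, Q) = \inf_{\gamma \in \Pi(P, Q)} \mathbb{E}_{(x, y) \sim \gamma}\left[\lVert x - y \rVert\right]$$

Here $\Pi(P, Q)$ is the set of joint distributions whose marginals are $P$ and $Q$. KL is asymmetric and can be infinite, JS is symmetric and bounded by $\log 2$, and $W$ keeps growing with the distance between the distributions even when they do not overlap.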
In summary, JS divergence has gradient issues that lead to unstable training, so WGAN instead bases its loss on the Wasserstein distance.
Potential Problem exists in GAN
Mode Collapse
- the generator always produces the same (or nearly the same) image
- can be observed easily
- no general solution, it cannot be reliably avoided (workaround: keep checkpoints and use one saved before mode collapse)
Mode Dropping
- the generated data covers only part of the real data distribution (some variations are missing), even though each sample looks realistic
- Cannot be observed easily
Evaluation methods
Evaluating the quality of synthesized images is an open and difficult problem. Traditional metrics such as per-pixel mean-squared error do not assess joint statistics of the result, and therefore do not measure the very structure that structured losses aim to capture.
Human Preference
A human evaluation method mentioned in the Pix2Pix paper: "real vs. fake" perceptual studies on Amazon Mechanical Turk (AMT).
- A perceptual test for human observers (As plausibility to a human observer is often the ultimate goal)
- Detail of perceptual test:
- Turkers were presented with a series of trials that pitted a “real” image against a “fake” image generated by Pix2Pix. On each trial, each image appeared for 1 second, after which the images disappeared and Turkers were given unlimited time to respond as to which was fake.
- The first 10 images of each session were practice and Turkers were given feedback. No feedback was provided on the 40 trials of the main experiment. Each session tested just one algorithm at a time, and Turkers were not allowed to complete more than one session.
- 50 Turkers evaluated each algorithm.
- Human evaluation is expensive (and sometimes unfair/unstable)
- How to evaluate the quality of the generated images automatically?
Inception Score (IS)
Using a pre-trained image classification network (InceptionV3) to classify the generated images, we can compute the Inception Score.
- High score from the Inception network ⇒ good image quality and large image diversity
- Not suitable for all situations (e.g. datasets whose images do not resemble the ImageNet classes the network was trained on)
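The Inception Score is $\exp\left(\mathbb{E}_x\left[D_{KL}(p(y \mid x) \,\|\, p(y))\right]\right)$, where $p(y \mid x)$ is the classifier's softmax output for a generated image and $p(y)$ is its average over all generated images. Confident predictions (quality) and a spread-out $p(y)$ (diversity) both raise the score. A NumPy sketch that assumes the softmax outputs have already been computed (the classifier itself is not included; the function name is illustrative):

```python
import numpy as np

def inception_score(probs, eps=1e-12):
    """probs: array (num_images, num_classes) of softmax outputs p(y|x) from a classifier."""
    p_y = probs.mean(axis=0, keepdims=True)                                  # marginal p(y)
    kl = np.sum(probs * (np.log(probs + eps) - np.log(p_y + eps)), axis=1)  # KL per image
    return float(np.exp(kl.mean()))

print(inception_score(np.eye(10)))                          # confident and diverse: about 10
print(inception_score(np.tile([1.0, 0.0, 0.0], (5, 1))))    # same class for every image: 1.0
```

With $K$ classes the score lies between 1 and $K$: it reaches $K$ when predictions are one-hot and spread evenly over all classes, and 1 when every image gets the same prediction.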
Frechet Inception Distance (FID)
Empirically estimates the distribution of real and generated images in a deep network space and computes the divergence between them. Intuitively, if the generated images are realistic, they should have similar summary statistics as real images, in any feature space.
- FID is not perfect. It doesn’t capture the conditioning (i.e. alignment between output and input). But it captures the marginal distribution.
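- Take the features from a late layer of the Inception network, typically the 2048-dimensional pool3 layer (or from any other CNN)
- Compute the Frechet distance between the two resulting Gaussian distributions
- Assumes the two distributions are Gaussian
- Smaller distance is better (although a very small FID, i.e. practically the same as the real data, may not be what we want, e.g. if the generator just copies training images)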
The common practice is to calculate the FIDs using 50,000 images drawn randomly from the training set, and report the lowest distance encountered over the course of training.
FID is defined as the Frechet distance between feature vectors of the real and generated images, taken from the Inception-v3 pool3 layer. Lower FID indicates better perceptual quality.
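With feature means $\mu_r, \mu_g$ and covariances $\Sigma_r, \Sigma_g$, the Frechet distance between the two Gaussians is $\lVert \mu_r - \mu_g \rVert^2 + \mathrm{Tr}\left(\Sigma_r + \Sigma_g - 2(\Sigma_r \Sigma_g)^{1/2}\right)$. A NumPy sketch that assumes the Inception features are already extracted into arrays of shape (num_images, dim). Common implementations use `scipy.linalg.sqrtm` for the matrix square root, but this version uses a symmetric eigendecomposition to stay NumPy-only (function names are illustrative):

```python
import numpy as np

def _sqrtm_psd(mat):
    # Square root of a symmetric positive semi-definite matrix via eigendecomposition
    vals, vecs = np.linalg.eigh(mat)
    vals = np.clip(vals, 0.0, None)  # remove tiny negative eigenvalues caused by round-off
    return (vecs * np.sqrt(vals)) @ vecs.T

def frechet_distance(mu1, sigma1, mu2, sigma2):
    # ||mu1 - mu2||^2 + Tr(sigma1 + sigma2 - 2 (sigma1 sigma2)^{1/2})
    sqrt_sigma1 = _sqrtm_psd(sigma1)
    # Tr((sigma1 sigma2)^{1/2}) equals Tr((sqrt(sigma1) sigma2 sqrt(sigma1))^{1/2})
    cross = _sqrtm_psd(sqrt_sigma1 @ sigma2 @ sqrt_sigma1)
    diff = mu1 - mu2
    return float(diff @ diff + np.trace(sigma1) + np.trace(sigma2) - 2.0 * np.trace(cross))

def fid_from_features(real_feats, fake_feats):
    # Fit one Gaussian (mean, covariance) to each set of feature vectors
    mu_r, mu_f = real_feats.mean(axis=0), fake_feats.mean(axis=0)
    sig_r = np.cov(real_feats, rowvar=False)
    sig_f = np.cov(fake_feats, rowvar=False)
    return frechet_distance(mu_r, sig_r, mu_f, sig_f)

rng = np.random.default_rng(0)
real_feats = rng.normal(size=(1000, 8))            # stand-ins for Inception features
fake_feats = rng.normal(loc=0.5, size=(1000, 8))
print(fid_from_features(real_feats, fake_feats))
```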