This post is a summary of some of my notes on vanilla generative adversarial networks (GANs).
What is a generative adversarial network?
A GAN is a neural network architecture in which two separate neural networks are trained against each other, each using the other’s outputs, in a kind of adversarial game.
One network, the generator $G$, tries to generate samples that the other network, the discriminator $D$, classifies as “real” (where “real” means drawn from the same distribution as some training set). $D$, in turn, tries to accurately discriminate between real and generated samples.
Example setup and overview of training process
Consider the following potential setup of a GAN:
- We have a training set of some data (e.g., images of faces) whose distribution we want $G$ to model accurately, so that $G$ can generate realistic samples from it
- Generator network $G$:
  - Input: random noise $z$
  - Output: some sample $G(z)$, e.g., an image
- Discriminator network $D$:
  - Input: either $G(z)$ or some “real” sample $x$
  - Output: a value between 0 and 1 representing its confidence that its input is “real” rather than generated by $G$. The value increases as $D$ becomes more confident the input is real, and decreases as $D$ becomes more confident the input is fake.
What might a single iteration of the training loop look like?
- Sample random noise $z$ and generate a sample $G(z)$.
- Feed $G(z)$ and a real sample $x$ into $D$ to get $D(G(z))$ and $D(x)$, and use them to calculate the discriminator loss.
- Use $D(G(z))$ to calculate the generator loss.
- Calculate the gradient of each network’s loss with respect to that network’s own weights, and update the weights of $D$ and $G$ accordingly.
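The training iteration above can be sketched end to end on a deliberately tiny, hypothetical problem. Everything in this sketch is an assumption made for illustration: the “training set” is 1-D samples from a normal distribution with mean 4, $G$ is an affine map $G(z) = az + b$, $D$ is logistic regression $D(x) = \sigma(wx + c)$, gradients are derived by hand rather than with autodiff, updates are plain SGD, and $G$ uses the $-\log D(G(z))$ generator loss that the original paper suggests as a practical alternative.

```python
import math
import random

def sigmoid(s):
    return 1.0 / (1.0 + math.exp(-s))

def generator(z, a, b):
    # Hypothetical toy generator: an affine map from 1-D noise to a 1-D sample
    return a * z + b

def discriminator(x, w, c):
    # Hypothetical toy discriminator: logistic regression, output in (0, 1)
    return sigmoid(w * x + c)

def losses(p, xs, zs):
    """Batch-mean discriminator loss and (non-saturating) generator loss."""
    d_loss = g_loss = 0.0
    for x, z in zip(xs, zs):
        d_real = discriminator(x, p["w"], p["c"])
        d_fake = discriminator(generator(z, p["a"], p["b"]), p["w"], p["c"])
        d_loss += -math.log(d_real) - math.log(1.0 - d_fake)
        g_loss += -math.log(d_fake)
    return d_loss / len(xs), g_loss / len(xs)

def gradients(p, xs, zs):
    """Hand-derived gradients: D's loss w.r.t. (w, c), G's loss w.r.t. (a, b)."""
    g = {"w": 0.0, "c": 0.0, "a": 0.0, "b": 0.0}
    for x, z in zip(xs, zs):
        fake = generator(z, p["a"], p["b"])
        d_real = discriminator(x, p["w"], p["c"])
        d_fake = discriminator(fake, p["w"], p["c"])
        # d/ds of -log(sigmoid(s)) is sigmoid(s) - 1; of -log(1 - sigmoid(s)) is sigmoid(s)
        g["w"] += (d_real - 1.0) * x + d_fake * fake
        g["c"] += (d_real - 1.0) + d_fake
        # Chain rule through D (its weights held fixed) into G's parameters
        g["a"] += (d_fake - 1.0) * p["w"] * z
        g["b"] += (d_fake - 1.0) * p["w"]
    return {k: v / len(xs) for k, v in g.items()}

def train_step(p, xs, zs, lr=0.05):
    # Both gradients use the current weights; then both networks are updated
    grads = gradients(p, xs, zs)
    return {k: p[k] - lr * grads[k] for k in p}

random.seed(0)
params = {"w": 0.1, "c": 0.0, "a": 1.0, "b": 0.0}
for _ in range(200):
    real = [random.gauss(4.0, 1.0) for _ in range(32)]   # hypothetical data: N(4, 1)
    noise = [random.gauss(0.0, 1.0) for _ in range(32)]  # z ~ N(0, 1)
    params = train_step(params, real, noise)
print(params, losses(params, real, noise))
```

Here both gradients are computed from the same pre-update weights, matching the simultaneous update in the last step of the list; some implementations instead alternate, updating $D$ first and then recomputing $D(G(z))$ for the $G$ step.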
Loss functions
The loss functions commonly used in vanilla GANs are based on cross-entropy.
Minimax loss
In the original GANs paper “Generative Adversarial Nets” (Goodfellow et al., 2014), the following loss functions are used for the discriminator and generator:
Discriminator loss = $-E_{x\sim p_{data}(x)}[\log{D(x)}] - E_{z\sim p_z(z)}[\log(1 - D(G(z)))]$
- $-E_{x\sim p_{data}(x)}[\log{D(x)}]$ can be interpreted as the “real” loss: how badly $D$ classifies real samples as real. In our setup, it’s the binary cross-entropy between 1 (the “true” value of $D(x)$) and the output $D(x)$ for real input $x$
- $-E_{z\sim p_z(z)}[\log(1 - D(G(z)))]$ can be interpreted as the “fake” loss: how badly $D$ classifies fake samples as fake. In our setup, it’s the binary cross-entropy between 0 (the “true” value of $D(G(z))$) and the output $D(G(z))$
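As a small numerical check of these two interpretations, the sketch below estimates the discriminator loss from batches of discriminator outputs and confirms that it equals the sum of the two binary cross-entropies. The output values are made up, and the expectations are replaced by batch means, as is usual in practice.

```python
import math

def bce(y_true, y_pred):
    # Binary cross-entropy for a single prediction y_pred in (0, 1)
    return -(y_true * math.log(y_pred) + (1 - y_true) * math.log(1 - y_pred))

def discriminator_loss(d_real, d_fake):
    """Batch estimate of -E[log D(x)] - E[log(1 - D(G(z)))].

    d_real holds D(x) for real samples; d_fake holds D(G(z)) for generated ones.
    """
    real_loss = sum(-math.log(d) for d in d_real) / len(d_real)
    fake_loss = sum(-math.log(1 - d) for d in d_fake) / len(d_fake)
    return real_loss + fake_loss

d_real = [0.9, 0.8]  # made-up outputs of D on real samples
d_fake = [0.1, 0.3]  # made-up outputs of D on generated samples

loss = discriminator_loss(d_real, d_fake)
# The same quantity written as BCE against target 1 (real) and target 0 (fake)
as_bce = (sum(bce(1, d) for d in d_real) / len(d_real)
          + sum(bce(0, d) for d in d_fake) / len(d_fake))
print(loss, as_bce)
```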
Generator loss = $E_{x\sim p_{data}(x)}[\log{D(x)}] + E_{z\sim p_z(z)}[\log(1 - D(G(z)))]$
- This is just the negative of the discriminator loss
- The first term, $E_{x\sim p_{data}(x)}[\log{D(x)}]$, depends only on $D$ and the real data, so $G$ cannot affect it; in effect, $G$ just tries to minimize $E_{z\sim p_z(z)}[\log(1 - D(G(z)))]$
- $E_{z\sim p_z(z)}[\log(1 - D(G(z)))]$ can be interpreted as a measure of how badly $G$ fools $D$ into thinking that its generated samples $G(z)$ are real. In our setup, it’s the negative binary cross-entropy between 0 (the “true” value of $D(G(z))$) and the output $D(G(z))$
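To make the direction of this objective concrete, the sketch below evaluates $E_{z\sim p_z(z)}[\log(1 - D(G(z)))]$ for two made-up batches of discriminator outputs. The value is near 0 when $D$ confidently rejects the generated samples and becomes more negative as $G$ fools $D$ more often, which is why $G$ tries to push it down. The function name is illustrative.

```python
import math

def minimax_generator_objective(d_fake):
    """Batch estimate of E_z[log(1 - D(G(z)))], the part of the minimax
    generator loss that G can actually change. G tries to minimize it."""
    return sum(math.log(1 - d) for d in d_fake) / len(d_fake)

# Made-up outputs D(G(z)) for a weak and a stronger generator
weak_g = [0.05, 0.10]    # D confidently rejects these samples
strong_g = [0.60, 0.70]  # D is mostly fooled

print(minimax_generator_objective(weak_g))    # close to 0: G is doing badly
print(minimax_generator_objective(strong_g))  # more negative: G is doing better
```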
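The paper shows that, when $D$ is optimal, minimizing this generator loss is equivalent to minimizing the Jensen-Shannon divergence between the data distribution $p_{data}$ and the generator’s distribution (see the paper for the proof).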
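A key step in that proof is that, for a fixed $G$, the optimal discriminator is $D^*(x) = \frac{p_{data}(x)}{p_{data}(x) + p_g(x)}$, where $p_g$ is the distribution of $G(z)$. Plugging $D^*$ into the generator loss above (with the second expectation rewritten over $x \sim p_g$) gives $-\log 4 + 2 \cdot JSD(p_{data} \,\|\, p_g)$, which is smallest, at $-\log 4$, exactly when $p_g = p_{data}$. The sketch below checks this identity numerically on two made-up discrete distributions; it is a sanity check, not the proof.

```python
import math

def jsd(p, q):
    """Jensen-Shannon divergence (natural log) between two discrete distributions."""
    m = [(pi + qi) / 2 for pi, qi in zip(p, q)]

    def kl(a, b):
        return sum(ai * math.log(ai / bi) for ai, bi in zip(a, b) if ai > 0)

    return 0.5 * kl(p, m) + 0.5 * kl(q, m)

def value_at_optimal_d(p_data, p_g):
    """E_{x~p_data}[log D*(x)] + E_{x~p_g}[log(1 - D*(x))] with D* = p_data / (p_data + p_g)."""
    total = 0.0
    for pd, pg in zip(p_data, p_g):
        if pd + pg == 0:
            continue  # outcome impossible under both distributions
        d_star = pd / (pd + pg)
        if pd > 0:
            total += pd * math.log(d_star)
        if pg > 0:
            total += pg * math.log(1 - d_star)
    return total

p_data = [0.1, 0.4, 0.5]  # made-up data distribution over three outcomes
p_g = [0.3, 0.3, 0.4]     # made-up generator distribution
lhs = value_at_optimal_d(p_data, p_g)
rhs = -math.log(4) + 2 * jsd(p_data, p_g)
print(lhs, rhs)  # the two agree up to floating-point error
```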
To help understand this section, recall that:
- In our setup, $D(x)$ outputs values between 0 and 1, with value 1 meaning it is completely confident $x$ is real, and 0 meaning it is completely confident $x$ is fake
- Cross-entropy of distributions $p$ and $q$: $H(p, q) = -\sum_{x \in X}p(x)\log(q(x)) = -E_{x\sim p(x)}[\log(q(x))]$
- Binary cross-entropy between $y_{true}$ and $y_{predicted}$: $-y_{true}\log(y_{predicted}) - (1 - y_{true})\log(1 - y_{predicted})$
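To see how these two definitions relate, the sketch below computes both for a single prediction: binary cross-entropy is just the cross-entropy between two Bernoulli distributions, the target $[y_{true}, 1 - y_{true}]$ and the prediction $[y_{predicted}, 1 - y_{predicted}]$. The function names are illustrative.

```python
import math

def cross_entropy(p, q):
    """H(p, q) = -sum_x p(x) log q(x) for discrete distributions given as lists."""
    # Terms with p(x) = 0 contribute nothing, so skip them (avoids log(0))
    return -sum(px * math.log(qx) for px, qx in zip(p, q) if px > 0)

def binary_cross_entropy(y_true, y_pred):
    return -(y_true * math.log(y_pred) + (1 - y_true) * math.log(1 - y_pred))

# BCE is the cross-entropy between the target Bernoulli [y_true, 1 - y_true]
# and the predicted Bernoulli [y_pred, 1 - y_pred]
y_true, y_pred = 1.0, 0.8
print(binary_cross_entropy(y_true, y_pred))                       # -log(0.8)
print(cross_entropy([y_true, 1 - y_true], [y_pred, 1 - y_pred]))  # same value
```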
Modified minimax loss
The original GANs paper notes that, in practice, the minimax generator loss may not provide sufficient gradient for $G$ to learn well, particularly early in training, and proposes an alternative generator loss.
Modified generator loss = $-E_{z\sim p_z(z)}[\log(D(G(z)))]$
- Can be interpreted as a measure of how badly $G$ fools $D$ into thinking that its generated samples $G(z)$ are real. It’s the binary cross-entropy between 1 (the “true” value of $D(G(z))$ from $G$’s point of view) and the output $D(G(z))$
Why is this better than the previous loss, $E_{z\sim p_z(z)}[\log(1 - D(G(z)))]$?
- In the early stages of training, when $G$ does not yet produce good samples, $D$ can reject them very confidently and output very low values for $D(G(z))$. This causes the gradients of $\log(1 - D(G(z)))$ to be relatively weak there compared with those at higher values of $D(G(z))$.
- If we use $-\log(D(G(z)))$ instead, we see the opposite effect: gradients are relatively large when $G$ is doing poorly, and small when $G$ is doing well.
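The effect is clearest when looking at gradients with respect to the discriminator’s logit $s$, assuming (as is common, though not stated above) that $D$ ends in a sigmoid, so $D(G(z)) = \sigma(s)$. Then $\frac{d}{ds}\log(1 - \sigma(s)) = -\sigma(s)$, which vanishes as $D(G(z)) \to 0$, while $\frac{d}{ds}\left[-\log \sigma(s)\right] = \sigma(s) - 1$, which stays close to $-1$ in that region. This is the gradient that backpropagation passes on toward $G$’s weights. The sketch below prints both at a few logits.

```python
import math

def sigmoid(s):
    return 1.0 / (1.0 + math.exp(-s))

def grad_minimax(s):
    # d/ds log(1 - sigmoid(s)) = -sigmoid(s) = -D(G(z))
    return -sigmoid(s)

def grad_modified(s):
    # d/ds [-log(sigmoid(s))] = sigmoid(s) - 1 = D(G(z)) - 1
    return sigmoid(s) - 1.0

# s is D's logit on a generated sample; very negative s means D confidently says "fake"
for s in (-6.0, -2.0, 0.0, 2.0):
    print(f"D(G(z)) = {sigmoid(s):.4f}   minimax: {grad_minimax(s):+.4f}   "
          f"modified: {grad_modified(s):+.4f}")
```

At $s = -6$, where $D(G(z)) \approx 0.0025$, the minimax gradient is roughly 400 times smaller in magnitude than the modified one (the ratio is $e^{6}$).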
Here is a very simple example of how you might implement the modified minimax loss in TensorFlow:
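import tensorflow as tf

# Assumes D ends in a sigmoid (outputs in (0, 1)); use from_logits=True if D outputs raw logits.
cross_entropy = tf.keras.losses.BinaryCrossentropy()

def discriminator_loss(real_output, fake_output):
    real_loss = cross_entropy(tf.ones_like(real_output), real_output)   # BCE(1, D(x))
    fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output)  # BCE(0, D(G(z)))
    return real_loss + fake_loss

def generator_loss(fake_output):
    # Modified (non-saturating) generator loss: BCE(1, D(G(z)))
    return cross_entropy(tf.ones_like(fake_output), fake_output)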
Results versus non-adversarial methods
Here is a comparison of digits generated by a conditional vanilla GAN and by a conditional variational autoencoder (VAE), with roughly similar convolutional architectures and numbers of parameters. Both were trained on the MNIST dataset.
Notice how the samples generated by the VAE are much blurrier than those generated by the GAN.
Speculation on why GANs generate sharper and more realistic samples compared to VAEs
VAEs punish the generator for producing samples that don’t exactly match some reference, while GANs do not
Consider how VAEs are commonly trained for generating samples. The loss function may include mean-squared error between the generated sample and a reference sample. The VAE gets punished when it generates a sample that doesn’t match the reference sample exactly, even if it looks valid/realistic enough. As a result, it tends to learn to output something like an average of the reference samples, which tends to be blurry.
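A toy illustration of this averaging argument, using pixel-wise MSE alone (a real VAE objective also includes a KL term, so this is not a VAE): suppose two perfectly sharp 1-D “images”, an edge at two different positions, are equally valid references. The output that minimizes the expected MSE is their pixel-wise average, which has a half-intensity, blurry pixel exactly where the edge position is uncertain. The values are made up for illustration.

```python
def mse(a, b):
    return sum((ai - bi) ** 2 for ai, bi in zip(a, b)) / len(a)

# Two equally plausible, perfectly sharp 1-D "images": an edge at two positions
target_a = [0.0, 0.0, 1.0, 1.0]
target_b = [0.0, 1.0, 1.0, 1.0]
references = [target_a, target_b]

def expected_mse(output):
    # Average MSE against references that are equally likely
    return sum(mse(output, r) for r in references) / len(references)

blurry = [(a + b) / 2 for a, b in zip(target_a, target_b)]  # [0.0, 0.5, 1.0, 1.0]
print(expected_mse(target_a))  # 0.125: committing to one sharp edge
print(expected_mse(blurry))    # 0.0625: the blurry average scores better
```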
However, a GAN does not directly punish its generator for generating a sample that does not match some reference sample exactly. The generator is punished only according to how real the discriminator thinks its sample is. Additionally, the discriminator can learn to detect blurriness, associate it with fakeness, and reject those kinds of samples.
Acknowledgments
- “Generative Adversarial Nets” (Goodfellow et al., 2014)
- Ian Goodfellow’s NIPS 2016 tutorial on generative adversarial networks
- “Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network” (Ledig et al., 2017)