Intro to Diffusion Model — Part 3
In this post, we are going to derive the loss function of the diffusion model in detail and introduce the training algorithm.
This post is part of a series of posts about diffusion models:
- Intro to Diffusion Model — Part 1
- Intro to Diffusion Model — Part 2
- Intro to Diffusion Model — Part 3 (this post)
- Intro to Diffusion Model — Part 4
- Intro to Diffusion Model — Part 5
- Intro to Diffusion Model — Part 6
- Intro to Diffusion Model — Part 7
- Full implementation
In the previous part, we mathematically defined the forward and reverse processes. We also introduced the model we want to learn for going from a noisier image to a less noisy one.
We also said that we would simplify the problem by fixing the variance and only estimating the mean, as in the DDPM paper.
In this post, we want to build a loss function that will help us find the parameters of the model. A natural choice is the negative log-likelihood, −log p_θ(x₀). The problem with this approach is that p_θ(x₀) is intractable: computing it requires marginalizing over all T intermediate steps x₁, …, x_T of the reverse process.
So, instead of using the negative log-likelihood directly, we bound it from above by another function that we can optimize. Let’s look at the following inequality:
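In symbols, with q(x_{1:T} | x₀) as the forward process and p_θ(x_{1:T} | x₀) as the reverse process, the inequality can be written as follows. The KL divergence is expanded on the right so that p_θ sits in the numerator of the log:

$$
-\log p_\theta(x_0) \le -\log p_\theta(x_0) + D_{KL}\big(q(x_{1:T}\mid x_0)\,\|\,p_\theta(x_{1:T}\mid x_0)\big),
\qquad
D_{KL} = -\mathbb{E}_{q(x_{1:T}\mid x_0)}\left[\log\frac{p_\theta(x_{1:T}\mid x_0)}{q(x_{1:T}\mid x_0)}\right]
$$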
We add a non-negative value, the KL divergence, to the negative log-likelihood. The right-hand side is therefore always greater than or equal to the left-hand side, and by minimizing the right-hand side we also push down the negative log-likelihood. You may be asking how this helps, since the right-hand side again contains the original negative log-likelihood. That is a fair question. To solve it, let’s work a little with the log inside the KL divergence on the right-hand side. First, we can apply Bayes' rule to the numerator of the log argument:
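For the reverse-process term p_θ(x_{1:T} | x₀), Bayes' rule reads:

$$
p_\theta(x_{1:T}\mid x_0) = \frac{p_\theta(x_0\mid x_{1:T})\,p_\theta(x_{1:T})}{p_\theta(x_0)} = \frac{p_\theta(x_{0:T})}{p_\theta(x_0)}
$$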
Substituting this result into the log, we get
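Written out, the log term inside the KL divergence (together with its minus sign) becomes:

$$
-\log\frac{p_\theta(x_{1:T}\mid x_0)}{q(x_{1:T}\mid x_0)} = -\log\frac{p_\theta(x_{0:T})}{q(x_{1:T}\mid x_0)} + \log p_\theta(x_0)
$$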
By substituting the last result back into the inequality, we can get rid of the log-likelihood term and get
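Since log p_θ(x₀) does not depend on x_{1:T}, it leaves the expectation and cancels the −log p_θ(x₀) already on the right-hand side. The bound becomes:

$$
-\log p_\theta(x_0) \le -\mathbb{E}_{q(x_{1:T}\mid x_0)}\left[\log\frac{p_\theta(x_{0:T})}{q(x_{1:T}\mid x_0)}\right]
$$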
On the right-hand side, the denominator is the forward process starting from x₀, and the numerator is the reverse process, both as defined in the previous post. Let’s write them out explicitly inside the logarithm and expand further.
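The reverse process is p_θ(x_{0:T}) = p(x_T) ∏ₜ p_θ(x_{t−1} | xₜ) and the forward process is q(x_{1:T} | x₀) = ∏ₜ q(xₜ | x_{t−1}). With these, the right-hand side reads:

$$
-\mathbb{E}_q\left[\log\frac{p(x_T)\prod_{t=1}^{T}p_\theta(x_{t-1}\mid x_t)}{\prod_{t=1}^{T}q(x_t\mid x_{t-1})}\right]
= -\mathbb{E}_q\left[\log p(x_T) + \sum_{t=1}^{T}\log\frac{p_\theta(x_{t-1}\mid x_t)}{q(x_t\mid x_{t-1})}\right]
$$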
where we converted the product inside the log into a sum of logs. Now, let’s move the first term of the summation (t = 1) outside:
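Pulling the t = 1 term out of the sum gives:

$$
-\mathbb{E}_q\left[\log p(x_T) + \sum_{t=2}^{T}\log\frac{p_\theta(x_{t-1}\mid x_t)}{q(x_t\mid x_{t-1})} + \log\frac{p_\theta(x_0\mid x_1)}{q(x_1\mid x_0)}\right]
$$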
For the denominator in the summation, we can use Bayes’ rule:
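For the forward-process term q(xₜ | x_{t−1}), Bayes' rule reads:

$$
q(x_t\mid x_{t-1}) = \frac{q(x_{t-1}\mid x_t)\,q(x_t)}{q(x_{t-1})}
$$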
The problem with the reversed conditionals q(x_{t−1} | xₜ) in this expression is that they have high variance: we are in a noisy state with no information about the initial state. To handle this, we can condition on the initial state x₀.
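Because the forward process is Markov, q(xₜ | x_{t−1}) = q(xₜ | x_{t−1}, x₀). Conditioning on x₀ therefore gives:

$$
q(x_t\mid x_{t-1}) = q(x_t\mid x_{t-1}, x_0) = \frac{q(x_{t-1}\mid x_t, x_0)\,q(x_t\mid x_0)}{q(x_{t-1}\mid x_0)}
$$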
Putting it back into the original equation, we get
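Substituting the x₀-conditioned form into the sum over t ≥ 2:

$$
-\mathbb{E}_q\left[\log p(x_T) + \sum_{t=2}^{T}\log\left(\frac{p_\theta(x_{t-1}\mid x_t)}{q(x_{t-1}\mid x_t, x_0)}\cdot\frac{q(x_{t-1}\mid x_0)}{q(x_t\mid x_0)}\right) + \log\frac{p_\theta(x_0\mid x_1)}{q(x_1\mid x_0)}\right]
$$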
Now, let’s focus on the summation and split it as follows:
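Splitting the log of the product into two sums:

$$
\sum_{t=2}^{T}\log\left(\frac{p_\theta(x_{t-1}\mid x_t)}{q(x_{t-1}\mid x_t, x_0)}\cdot\frac{q(x_{t-1}\mid x_0)}{q(x_t\mid x_0)}\right)
= \sum_{t=2}^{T}\log\frac{p_\theta(x_{t-1}\mid x_t)}{q(x_{t-1}\mid x_t, x_0)} + \sum_{t=2}^{T}\log\frac{q(x_{t-1}\mid x_0)}{q(x_t\mid x_0)}
$$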
Look carefully at the second summation. In each term, the denominator is one step ahead of the numerator. If we turn the sum of logs into the log of a product, the denominator at step t cancels the numerator at step t+1 (a telescoping product). Only the first numerator and the last denominator survive.
The previous equation can now be written as
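Collapsing the telescoping sum:

$$
\sum_{t=2}^{T}\log\frac{q(x_{t-1}\mid x_0)}{q(x_t\mid x_0)} = \log\prod_{t=2}^{T}\frac{q(x_{t-1}\mid x_0)}{q(x_t\mid x_0)} = \log\frac{q(x_1\mid x_0)}{q(x_T\mid x_0)}
$$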
Let’s substitute this result back into the original equation and rearrange the logarithms:
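Call the right-hand side of the bound L. Combining log p(x_T) with the surviving q-terms cancels the two q(x₁ | x₀) factors:

$$
\begin{aligned}
L &= -\mathbb{E}_q\left[\log p(x_T) + \sum_{t=2}^{T}\log\frac{p_\theta(x_{t-1}\mid x_t)}{q(x_{t-1}\mid x_t, x_0)} + \log\frac{q(x_1\mid x_0)}{q(x_T\mid x_0)} + \log\frac{p_\theta(x_0\mid x_1)}{q(x_1\mid x_0)}\right] \\
&= -\mathbb{E}_q\left[\log\frac{p(x_T)}{q(x_T\mid x_0)} + \sum_{t=2}^{T}\log\frac{p_\theta(x_{t-1}\mid x_t)}{q(x_{t-1}\mid x_t, x_0)} + \log p_\theta(x_0\mid x_1)\right]
\end{aligned}
$$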
Finally, we can write the full loss in KL divergence notation:
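In KL notation, the loss is shown below. The labels follow the DDPM paper, which indexes the middle terms as L_{t−1}; these are the Lₜ terms discussed next.

$$
L = \mathbb{E}_q\Big[\underbrace{D_{KL}\big(q(x_T\mid x_0)\,\|\,p(x_T)\big)}_{L_T}
+ \sum_{t=2}^{T}\underbrace{D_{KL}\big(q(x_{t-1}\mid x_t, x_0)\,\|\,p_\theta(x_{t-1}\mid x_t)\big)}_{L_{t-1}}
\underbrace{-\log p_\theta(x_0\mid x_1)}_{L_0}\Big]
$$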
The first component on the right-hand side, L_T, can be ignored during training. The forward process q has no learnable parameters, and x_T is pure Gaussian noise. Let’s now look more closely at the second component, Lₜ, which is the KL divergence between q and p_θ. For p_θ, we explained in the previous post that it can be described by:
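The general reverse step, with a learned mean and covariance:

$$
p_\theta(x_{t-1}\mid x_t) = \mathcal{N}\big(x_{t-1};\ \mu_\theta(x_t, t),\ \Sigma_\theta(x_t, t)\big)
$$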
but we are going to use a simplified version where the variance is not learnable but is known:
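With the covariance fixed to σₜ²I (the DDPM paper reports similar results for σₜ² = βₜ and σₜ² = β̃ₜ):

$$
p_\theta(x_{t-1}\mid x_t) = \mathcal{N}\big(x_{t-1};\ \mu_\theta(x_t, t),\ \sigma_t^2 I\big)
$$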
Now, let’s look at q. This distribution is also a Gaussian, so we can write
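Writing its mean as μ̃ₜ and its variance as β̃ₜ:

$$
q(x_{t-1}\mid x_t, x_0) = \mathcal{N}\big(x_{t-1};\ \tilde{\mu}_t(x_t, x_0),\ \tilde{\beta}_t I\big)
$$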
Let’s recall the structure of a Gaussian distribution:
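Each pixel is treated independently, so consider a single coordinate. The density and its expanded exponent are:

$$
\mathcal{N}(x;\mu,\sigma^2) = \frac{1}{\sqrt{2\pi\sigma^2}}\exp\left(-\frac{(x-\mu)^2}{2\sigma^2}\right)
\propto \exp\left(-\frac{1}{2}\left(\frac{1}{\sigma^2}x^2 - \frac{2\mu}{\sigma^2}x + \frac{\mu^2}{\sigma^2}\right)\right)
$$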
Now, we can use Bayes’ rule to write
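Conditioning every factor on x₀:

$$
q(x_{t-1}\mid x_t, x_0) = \frac{q(x_t\mid x_{t-1}, x_0)\,q(x_{t-1}\mid x_0)}{q(x_t\mid x_0)}
$$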
Since all the distributions above are Gaussian, we can write, according to the development from the previous post:
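This uses q(xₜ | x_{t−1}) = N(√αₜ x_{t−1}, βₜ), q(x_{t−1} | x₀) = N(√ᾱ_{t−1} x₀, 1 − ᾱ_{t−1}) and q(xₜ | x₀) = N(√ᾱₜ x₀, 1 − ᾱₜ). Here αₜ = 1 − βₜ and ᾱₜ = α₁⋯αₜ, as in the DDPM paper:

$$
q(x_{t-1}\mid x_t, x_0) \propto \exp\left(-\frac{1}{2}\left(\frac{(x_t - \sqrt{\alpha_t}\,x_{t-1})^2}{\beta_t} + \frac{(x_{t-1} - \sqrt{\bar{\alpha}_{t-1}}\,x_0)^2}{1-\bar{\alpha}_{t-1}} - \frac{(x_t - \sqrt{\bar{\alpha}_t}\,x_0)^2}{1-\bar{\alpha}_t}\right)\right)
$$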
The left-hand side is a Gaussian, as we wrote earlier, so we want to rearrange the right-hand side into the Gaussian form. Then we can read off the mean and the variance.
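Expanding the squares and collecting powers of x_{t−1}:

$$
q(x_{t-1}\mid x_t, x_0) \propto \exp\left(-\frac{1}{2}\left(\left(\frac{\alpha_t}{\beta_t} + \frac{1}{1-\bar{\alpha}_{t-1}}\right)x_{t-1}^2 - \left(\frac{2\sqrt{\alpha_t}}{\beta_t}\,x_t + \frac{2\sqrt{\bar{\alpha}_{t-1}}}{1-\bar{\alpha}_{t-1}}\,x_0\right)x_{t-1} + f(x_t, x_0)\right)\right)
$$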
where f is some function of xₜ and x₀ that does not depend on x_{t−1}, so we do not need it to get the mean and the variance. Matching the coefficients with the Gaussian structure, we conclude that:
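We match the x_{t−1}² and x_{t−1} coefficients with 1/σ² and −2μ/σ². Simplifying with αₜ + βₜ = 1 and αₜᾱ_{t−1} = ᾱₜ gives:

$$
\begin{aligned}
\tilde{\beta}_t &= 1\Big/\left(\frac{\alpha_t}{\beta_t} + \frac{1}{1-\bar{\alpha}_{t-1}}\right) = \frac{1-\bar{\alpha}_{t-1}}{1-\bar{\alpha}_t}\,\beta_t \\
\tilde{\mu}_t(x_t, x_0) &= \left(\frac{\sqrt{\alpha_t}}{\beta_t}\,x_t + \frac{\sqrt{\bar{\alpha}_{t-1}}}{1-\bar{\alpha}_{t-1}}\,x_0\right)\tilde{\beta}_t
= \frac{\sqrt{\alpha_t}\,(1-\bar{\alpha}_{t-1})}{1-\bar{\alpha}_t}\,x_t + \frac{\sqrt{\bar{\alpha}_{t-1}}\,\beta_t}{1-\bar{\alpha}_t}\,x_0
\end{aligned}
$$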
In the previous post, we saw that we can write
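The closed-form forward step and its inversion for x₀:

$$
x_t = \sqrt{\bar{\alpha}_t}\,x_0 + \sqrt{1-\bar{\alpha}_t}\,\varepsilon_t
\quad\Longrightarrow\quad
x_0 = \frac{1}{\sqrt{\bar{\alpha}_t}}\left(x_t - \sqrt{1-\bar{\alpha}_t}\,\varepsilon_t\right),
\qquad \varepsilon_t \sim \mathcal{N}(0, I)
$$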
We substitute this into the previous equation and get
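After simplifying the coefficient of xₜ (again with αₜᾱ_{t−1} = ᾱₜ and αₜ + βₜ = 1), the posterior mean depends only on xₜ and the noise:

$$
\tilde{\mu}_t = \frac{1}{\sqrt{\alpha_t}}\left(x_t - \frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}}\,\varepsilon_t\right)
$$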
If we go back to the second (middle) term of the loss we derived, we can conclude that we want to train μ_θ to predict μ̃ₜ. In the reverse process xₜ is known, since it is the network's input. The only unknown in μ̃ₜ is therefore the noise, so what we actually need to predict is the noise at step t, εₜ.
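A quick numerical check can make this algebra concrete. The sketch below is a minimal sketch that assumes NumPy and the linear β schedule from the DDPM paper (βₜ from 10⁻⁴ to 0.02 over T = 1000 steps). The function names are hypothetical. The sketch computes μ̃ₜ both from x₀ and from the noise. It also shows how a predicted noise ε_θ(xₜ, t) would turn into μ_θ(xₜ, t) = (xₜ − βₜ/√(1 − ᾱₜ) · ε_θ(xₜ, t)) / √αₜ.

```python
import numpy as np

def ddpm_schedule(T=1000, beta_1=1e-4, beta_T=0.02):
    """Linear beta schedule. Returns beta_t, alpha_t = 1 - beta_t and
    alpha_bar_t = prod_{s<=t} alpha_s; index t - 1 holds the value for step t."""
    betas = np.linspace(beta_1, beta_T, T)
    alphas = 1.0 - betas
    alpha_bars = np.cumprod(alphas)
    return betas, alphas, alpha_bars

def posterior_from_x0(x_t, x0, t, betas, alphas, alpha_bars):
    """Mean and variance of q(x_{t-1} | x_t, x_0) for step t >= 2, written in terms of x_0."""
    beta, alpha, ab = betas[t - 1], alphas[t - 1], alpha_bars[t - 1]
    ab_prev = alpha_bars[t - 2]
    var = (1.0 - ab_prev) / (1.0 - ab) * beta
    mean = ((np.sqrt(alpha) * (1.0 - ab_prev) / (1.0 - ab)) * x_t
            + (np.sqrt(ab_prev) * beta / (1.0 - ab)) * x0)
    return mean, var

def mean_from_noise(x_t, eps, t, betas, alphas, alpha_bars):
    """The same posterior mean written in terms of the noise. Passing a network's
    noise prediction eps_theta(x_t, t) instead of the true eps gives mu_theta(x_t, t)."""
    beta, alpha, ab = betas[t - 1], alphas[t - 1], alpha_bars[t - 1]
    return (x_t - beta / np.sqrt(1.0 - ab) * eps) / np.sqrt(alpha)

# Usage: corrupt a toy "image" to step t, then compare both forms of the posterior mean.
rng = np.random.default_rng(0)
betas, alphas, alpha_bars = ddpm_schedule()
t = 500
x0 = rng.uniform(-1.0, 1.0, size=8)
eps = rng.standard_normal(8)
x_t = np.sqrt(alpha_bars[t - 1]) * x0 + np.sqrt(1.0 - alpha_bars[t - 1]) * eps
mean_a, var = posterior_from_x0(x_t, x0, t, betas, alphas, alpha_bars)
mean_b = mean_from_noise(x_t, eps, t, betas, alphas, alpha_bars)
print(np.allclose(mean_a, mean_b))
```

Both functions agree because they are the same formula before and after substituting x₀ in terms of xₜ and εₜ. That substitution is why a network that predicts the noise is enough to produce the mean.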
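Since both distributions are Gaussian and only the mean is learned (the variances are fixed), the KL divergence reduces to a scaled squared distance between the two means, up to a constant. This leads to the following loss: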
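Here σₜ² is the fixed variance of p_θ and C is a constant that does not depend on θ:

$$
L_{t-1} = \mathbb{E}_q\left[\frac{1}{2\sigma_t^2}\left\|\tilde{\mu}_t(x_t, x_0) - \mu_\theta(x_t, t)\right\|^2\right] + C
= \mathbb{E}_{x_0,\varepsilon}\left[\frac{\beta_t^2}{2\sigma_t^2\,\alpha_t\,(1-\bar{\alpha}_t)}\left\|\varepsilon - \varepsilon_\theta\big(\sqrt{\bar{\alpha}_t}\,x_0 + \sqrt{1-\bar{\alpha}_t}\,\varepsilon,\ t\big)\right\|^2\right] + C
$$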
The DDPM paper found empirically that training works better when the scaling factor is omitted, so we can simplify the loss to
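Dropping the weight in front of the squared error:

$$
L_{t-1}^{\text{simple}} = \mathbb{E}_{x_0,\varepsilon}\left[\left\|\varepsilon - \varepsilon_\theta\big(\sqrt{\bar{\alpha}_t}\,x_0 + \sqrt{1-\bar{\alpha}_t}\,\varepsilon,\ t\big)\right\|^2\right]
$$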
That’s the objective for the middle terms (in the summation) of the full loss expression. Now, we need to handle the last term of the loss, L₀. For this term, the DDPM paper suggests using an independent discrete decoder derived from the Gaussian N(x₀; μ_θ(x₁, 1), σ₁²I).
The paper also assumes that the image consists of discrete pixel values in {0, 1, 2, …, 255}, scaled linearly to the range [−1, 1]. The distribution is then represented as:
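Here i runs over the D coordinates of x₀:

$$
p_\theta(x_0\mid x_1) = \prod_{i=1}^{D}\int_{\delta_-(x_0^i)}^{\delta_+(x_0^i)}\mathcal{N}\big(x;\ \mu_\theta^i(x_1, 1),\ \sigma_1^2\big)\,dx
$$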
where
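The integration limits δ₊ and δ₋ open the edge bins to ±∞:

$$
\delta_+(x) = \begin{cases}\infty & \text{if } x = 1\\ x + \frac{1}{255} & \text{if } x < 1\end{cases}
\qquad
\delta_-(x) = \begin{cases}-\infty & \text{if } x = -1\\ x - \frac{1}{255} & \text{if } x > -1\end{cases}
$$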
and D is the data dimensionality (the number of values in the image, i.e., pixels times color channels). Intuitively, our network predicts the mean of each pixel, and we use this prediction as the center of a Gaussian distribution. We then integrate this Gaussian over the bin around the real pixel value in the original image x₀, from the real value minus 1/255 to the real value plus 1/255. The edge bins at −1 and 1 extend to −∞ and +∞. If the predicted mean is close to the original value of the pixel, the integral will be large.
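Here is a minimal pure-Python sketch that evaluates this log-likelihood for a few pixels, to make the decoder concrete. It assumes one shared standard deviation `sigma` (standing in for σ₁) and a list of per-pixel predicted means. The function names are hypothetical, and the clamping constant 1e-12 is an arbitrary choice to avoid log(0).

```python
import math

def normal_cdf(z):
    """CDF of the standard normal distribution."""
    return 0.5 * (1.0 + math.erf(z / math.sqrt(2.0)))

def discretized_gaussian_log_likelihood(x0, mean, sigma):
    """Sum over pixels of log P(pixel lands in its bin), pixels scaled to [-1, 1].

    An interior value x owns the bin [x - 1/255, x + 1/255]; the bins of the
    edge values -1 and 1 extend to -inf and +inf, as in the delta definitions."""
    total = 0.0
    for x, mu in zip(x0, mean):
        cdf_upper = 1.0 if x >= 1.0 else normal_cdf((x + 1.0 / 255.0 - mu) / sigma)
        cdf_lower = 0.0 if x <= -1.0 else normal_cdf((x - 1.0 / 255.0 - mu) / sigma)
        prob = max(cdf_upper - cdf_lower, 1e-12)  # clamp to avoid log(0) when the mean is far off
        total += math.log(prob)
    return total

# Usage: three pixel values mapped from {0, ..., 255} to [-1, 1].
pixels = [0, 128, 255]
x0 = [2.0 * p / 255.0 - 1.0 for p in pixels]
good = discretized_gaussian_log_likelihood(x0, x0, sigma=0.01)                     # means on target
bad = discretized_gaussian_log_likelihood(x0, [m + 0.2 for m in x0], sigma=0.01)  # means shifted by 0.2
print(good, bad)
```

When the predicted means sit on the true pixel values, each bin captures much more probability mass, so `good` is larger than `bad`.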
However, the paper simplifies this term as well. Instead of calculating the integral, it approximates it by the Gaussian probability density function times the bin width:
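Replacing each integral by the density at x₀ⁱ times the bin width 2/255:

$$
p_\theta(x_0\mid x_1) \approx \prod_{i=1}^{D}\frac{2}{255}\,\mathcal{N}\big(x_0^i;\ \mu_\theta^i(x_1, 1),\ \sigma_1^2\big)
$$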
and the log will be
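Taking the log, with C collecting the terms that do not depend on θ:

$$
\log p_\theta(x_0\mid x_1) \approx -\frac{1}{2\sigma_1^2}\left\|x_0 - \mu_\theta(x_1, 1)\right\|^2 + C
$$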
where C is a constant that does not depend on θ, so we can ignore it. We are also going to ignore the scaling factor 1/(2σ₁²) and recall the relations:
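For t = 1, note that ᾱ₁ = α₁ = 1 − β₁:

$$
x_1 = \sqrt{\bar{\alpha}_1}\,x_0 + \sqrt{1-\bar{\alpha}_1}\,\varepsilon,
\qquad
\mu_\theta(x_1, 1) = \frac{1}{\sqrt{\alpha_1}}\left(x_1 - \frac{\beta_1}{\sqrt{1-\bar{\alpha}_1}}\,\varepsilon_\theta(x_1, 1)\right)
$$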
Substituting these and again ignoring the scaling factor, we get
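Plugging x₁ into μ_θ(x₁, 1) gives x₀ − μ_θ = −√(β₁/α₁)(ε − ε_θ). After dropping the factor β₁/(2σ₁²α₁), the decoder term has the same form as the simplified loss at t = 1:

$$
L_0 \approx \mathbb{E}_{x_0,\varepsilon}\left[\left\|\varepsilon - \varepsilon_\theta\big(\sqrt{\bar{\alpha}_1}\,x_0 + \sqrt{1-\bar{\alpha}_1}\,\varepsilon,\ 1\big)\right\|^2\right]
$$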
Simplified Training Objective
Based on the derivation above, the DDPM paper suggests using this loss function for training:
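Written out, with t drawn uniformly from {1, …, T}:

$$
L_{\text{simple}}(\theta) = \mathbb{E}_{t\sim\mathcal{U}\{1,\dots,T\},\ x_0\sim q(x_0),\ \varepsilon\sim\mathcal{N}(0, I)}\left[\left\|\varepsilon - \varepsilon_\theta\big(\sqrt{\bar{\alpha}_t}\,x_0 + \sqrt{1-\bar{\alpha}_t}\,\varepsilon,\ t\big)\right\|^2\right]
$$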
In the expression above, x₀ is the real, noise-free image we start with, and t is the sampled time step. ε is Gaussian noise sampled for step t, and ε_θ is the neural network’s prediction of that noise. Our final goal is to estimate the noise at each step, so we minimize the mean squared error between the true noise and the predicted one.
Based on this loss function, the DDPM paper suggests the following algorithm for training:
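Algorithm 1 (Training) of the DDPM paper, in the notation used here:

$$
\begin{array}{l}
\textbf{repeat}\\
\quad x_0 \sim q(x_0)\\
\quad t \sim \mathrm{Uniform}(\{1,\dots,T\})\\
\quad \varepsilon \sim \mathcal{N}(0, I)\\
\quad \text{Take a gradient descent step on } \nabla_\theta\left\|\varepsilon - \varepsilon_\theta\big(\sqrt{\bar{\alpha}_t}\,x_0 + \sqrt{1-\bar{\alpha}_t}\,\varepsilon,\ t\big)\right\|^2\\
\textbf{until converged}
\end{array}
$$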
The algorithm samples a random real image x₀ from the data distribution q(x₀). It then samples a time step t uniformly from {1, …, T}, which determines the noise level, and draws a noise sample ε from a standard Gaussian distribution. Using that noise, x₀ is corrupted into xₜ in closed form using the known noise schedule (through ᾱₜ, which is computed from β₁, …, βₜ). Finally, the network takes a gradient step toward predicting ε from xₜ and t.
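The training loop can be sketched in a few lines. This is a minimal sketch with several assumptions. It uses NumPy and the linear β schedule from the DDPM paper. It also replaces the U-Net (conditioned on t) used in practice with a hypothetical toy linear noise predictor ε_θ(xₜ) = xₜ W that has a hand-written gradient. Only the sampling of x₀, t, ε and xₜ and the squared-error objective follow the training algorithm described here.

```python
import numpy as np

def ddpm_schedule(T=1000, beta_1=1e-4, beta_T=0.02):
    """Linear beta schedule; index t - 1 holds the value for step t."""
    betas = np.linspace(beta_1, beta_T, T)
    alpha_bars = np.cumprod(1.0 - betas)
    return betas, alpha_bars

def training_step(W, x0, alpha_bars, lr, rng):
    """One training iteration for a toy linear noise predictor
    eps_theta(x_t) = x_t @ W (hypothetical; it ignores t, unlike a real network)."""
    batch = x0.shape[0]
    T = len(alpha_bars)
    t = rng.integers(1, T + 1, size=batch)             # t ~ Uniform({1, ..., T})
    eps = rng.standard_normal(x0.shape)                # eps ~ N(0, I)
    ab = alpha_bars[t - 1][:, None]                    # alpha_bar_t for each sample
    x_t = np.sqrt(ab) * x0 + np.sqrt(1.0 - ab) * eps   # corrupt x0 in closed form
    diff = x_t @ W - eps                               # predicted noise minus true noise
    loss = np.mean(np.sum(diff ** 2, axis=1))          # ||eps - eps_theta||^2, batch average
    grad = 2.0 * x_t.T @ diff / batch                  # gradient of the loss w.r.t. W
    return W - lr * grad, loss                         # gradient descent step

# Usage: toy "images" with 4 values each, uniform in [-1, 1].
rng = np.random.default_rng(0)
_, alpha_bars = ddpm_schedule()
data = rng.uniform(-1.0, 1.0, size=(512, 4))
W = np.zeros((4, 4))
losses = []
for step in range(300):
    x0 = data[rng.integers(0, len(data), size=64)]     # x0 ~ q(x0): a random minibatch
    W, loss = training_step(W, x0, alpha_bars, lr=0.05, rng=rng)
    losses.append(loss)
print(losses[0], np.mean(losses[-20:]))
```

A real implementation would let a deep-learning framework compute the gradient automatically. The hand-written gradient here only works because the toy model is linear.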