We recommend that you use Google Colab, as training will be faster on the GPU.
To enable the GPU on Colab, go to Edit / Notebook settings / Hardware accelerator / select T4 GPU
Instructions on how to download and use Jupyter Notebooks can be found here. You can find a static version of the notebook below.
0 Environment Setup¶
# Download the cleanfid package for FID score calculation
!pip install git+https://github.com/GaParmar/clean-fid.git
import torch
import os
import pathlib
import numpy as np
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
from diffusers import UNet2DModel
from tqdm import trange
from torch.utils.data import Subset
from torchvision.utils import save_image
from scipy.stats import multivariate_normal
from PIL import Image
from cleanfid import fid
device = "cpu"
if torch.cuda.is_available():
device = "cuda"
elif torch.backends.mps.is_available():
device = "mps"
1 Data Probability Distributions¶
In several of the past chapters we have discussed the difference between the data on which a model is trained and the data on which a model is executed. We have emphasized that these two sets are different, but we have not discussed in much detail how they are related. Their relationship is that both are sampled from a common probability distribution.
A probability distribution $p(x)$ is a function that assigns values to the likelihood of observing different possible outcomes $x$. The likelihood is large for signals that are likely and small for signals that are unlikely. As an example, consider a bag that contains all of the possible digits that have, are, or will ever be written. The first image in Figure 1 has a high likelihood as it is a fairly typical handwritten number five. The second image has a smaller likelihood. It is a discernible number five, but not a typical one. The third image has negligible likelihood, maybe even zero likelihood, because it does not look like a digit.
By convention, likelihoods are normalized so that they integrate to 1, $$ \int p(x) \, dx = 1 .\tag{1} $$ The meaning of this normalization is that when we put our hand in the bag we extract one and only one digit. Thus, the sum (integral) of all likelihoods must be exactly one.
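As a quick numerical sanity check, not part of the original derivation, the snippet below verifies that a scalar Gaussian density integrates to approximately 1; the truncation interval and grid size are arbitrary choices.
# Sketch: numerically verify that a standard Gaussian density integrates to (approximately) 1
import numpy as np
from scipy.stats import norm
grid = np.linspace(-10, 10, 10001)                 # truncate the integral to [-10, 10]
approx_integral = np.trapz(norm.pdf(grid), grid)   # trapezoid approximation of the integral in (1)
print(approx_integral)                             # prints a value very close to 1.0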
When we train an AI model it is implicit that training samples are extracted from an underlying probability distribution $p(x)$ which is the same distribution from which future samples will be extracted. In digit classification we pull handwritten digits from the bag of all digits that have, are, or will ever be written. We use these digits to train a CNN that we will then execute on any digits that are extracted from this same bag.
In this chapter we are after a more ambitious AI task than in previous chapters. We want to train an AI system that replicates the probability distribution itself. We are not just after classifying digits pulled from the bag of all digits that have, are, or will ever be written. We want to create an artificial bag from where we can pull digits with a distribution of likelihoods equivalent to the distribution of likelihoods of the natural system.
This is the most ambitious AI task we can conceive. If we succeed, we can dispense with reality altogether. Tackling this task requires that we understand distributions a little better and that we address the difference between statistical and empirical risk. We do that in the following sections before discussing generative models in Section 4.
2 Gaussian Distributions¶
A Gaussian probability distribution with mean $\mu$ and variance $\sigma^2$ is a probability distribution over scalar values $x$; see Figure 2. This probability distribution has the formula $$ p(x) = \frac{1}{\sqrt{2\pi}\sigma} \exp \left[ -\frac{(x-\mu)^2}{2\sigma^2} \right] .\tag{2} $$ All Gaussian distributions look like cross-sections of a bell. The mean $\mu$ determines the center of the bell and the variance $\sigma^2$ the concentration of the bell around its mean. Smaller variances imply larger concentrations. Gaussian distributions are also called normal distributions.
In Figure 2 the probability distribution $p_1(x)$ is normal and has mean $\mu=0$ and variance $\sigma^2=1$. We see in the figure that values around $x=\mu=0$ are the most likely to be observed, that values around $x$ and $-x$ are equally likely, and that values much larger than $x= 3\sigma = 3$ are unlikely. Distribution $p_2(x)$ in Figure 2 has mean $\mu=4$ and variance $\sigma^2=1$. It has the same shape as distribution $p_1(x)$ but is shifted to the right. This is because they have the same variance while the mean of $p_2(x)$ is larger. Distribution $p_3(x)$ has mean $\mu=8$ and variance $\sigma^2=1/4$. In addition to being further shifted to the right, distribution $p_3(x)$ is more concentrated around its mean. This is because its variance is smaller than the variance of $p_1(x)$ and $p_2(x)$.
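The short sketch below, not the original figure code, evaluates and plots the three densities just described; the plotting range is an arbitrary choice.
# Sketch: evaluate and plot the three 1D Gaussians discussed above (cf. Figure 2)
import numpy as np
from scipy.stats import norm
import matplotlib.pyplot as plt
x = np.linspace(-4, 12, 1000)
plt.plot(x, norm.pdf(x, loc=0, scale=1), label="p1: mu = 0, var = 1")
plt.plot(x, norm.pdf(x, loc=4, scale=1), label="p2: mu = 4, var = 1")
plt.plot(x, norm.pdf(x, loc=8, scale=0.5), label="p3: mu = 8, var = 1/4")   # scale is the standard deviation
plt.xlabel("x")
plt.ylabel("p(x)")
plt.legend()
plt.show()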
2.1 Multivariate Gaussian Distributions¶
Our interest extends to probability distributions across vectors $x\in\mathbb{R}^n$. In this case we define a vector of means $\mu \in \mathbb{R} ^n$ and a covariance matrix $C\in\mathbb{R}^{n\times n}$. A multivariate Gaussian probability distribution assigns likelihoods to samples according to the formula, $$ p(x) = \frac{1}{(2\pi)^{n/2} \det^{1/2}(C)}~ \exp \left[ -\frac{1}{2} (x-\mu)^T C^{-1} (x-\mu) \right] = \mathcal{N}\Big(\, x; \mu, C \,\Big).\tag{3} $$ In this formula $\det^{1/2}(C)$ is the square root of the determinant of the covariance matrix and the product $(x-\mu)^T C^{-1} (x-\mu)$ evaluates to a number that we exponentiate. The matrix $C=C^T$ is symmetric.
However complicated, the formula for the normal distribution is just a formula. To save the time of writing (3), we use the shorthand $$ p(x) = \mathcal{N} \Big(\, x;\, \mu,\, C \,\Big), \tag{4} $$ to denote a normal distribution with mean $\mu$ and covariance matrix $C$ evaluated at $x$.
Multivariate Gaussian distributions have bell shapes, or what we imagine a bell to be in high dimensions. What we can visualize is that one and two dimensional cross-sections of a multivariate Gaussian have bell shapes. In any case, the mean $\mu$ of a multivariate normal distribution represents the vector $x$ with the highest likelihood. The covariance matrix quantifies the concentration of the distribution around its mean. Since we are in multiple dimensions the concentration can be different in different directions.
An important particular case is when the covariance matrix $C = \sigma^2 I$ is a scaled identity. In this case the determinant is $\det(C)=\sigma^{2n}$ and the product in the exponent is $(x-\mu)^T C^{-1} (x-\mu)= \|x-\mu\|^2/\sigma^2$. These substitutions yield the probability distribution $$ p(x) ~=~ \frac{1}{ (2\pi\sigma^2)^{n/2} }~ \exp \left[\, - \frac{1}{2\sigma^2} \|x-\mu\|^2 \,\right] ~=~ \mathcal{N}\Big(\, x;\, \mu,\, \sigma^2I \,\Big) .\tag{5} $$ This is a symmetric distribution. The likelihood of observing a vector $x$ depends on its distance $\|x-\mu\|$ to the mean only. It is the same in any direction. We call this distribution white and we say that the scalar $\sigma^2$ is its variance. In the particular case when the variance is $\sigma^2 = 1$ and the mean is $\mu=\bf{0}$ we further say that the distribution is standard. We show a standard white distribution in Figure 3 when $n=2$.
Distributions that are not white are called colored. Figure 5 depicts colored Gaussian distributions in two dimensions. All of the distributions have zero mean but different covariance matrices. Since we are in two dimensions, the covariance matrix $C$ has 2 rows and 2 columns. We write it explicitly as $$ C = \left( \begin{array}{ll} \sigma_{11}^2 & \sigma_{12} \\ \sigma_{12} & \sigma_{22}^2 \end{array} \right) . \tag{6} $$ The first distribution in Figure 5 is more spread along the horizontal axis and more concentrated along the vertical axis. It corresponds to a covariance matrix with $\sigma_{11}^2 = 36$, $\sigma_{22}^2 = 4$ and $\sigma_{12} = 0$. The second distribution in Figure 5 flips the axes. It is more spread along the vertical axis and more concentrated along the horizontal axis. It corresponds to a covariance matrix with $\sigma_{11}^2 = 4$, $\sigma_{22}^2 = 36$ and $\sigma_{12} = 0$. The third distribution is a rotated version of the previous two. It is spread out along the ascending diagonal and concentrated in the descending diagonal. It corresponds to $\sigma_{11}^2 = 20$, $\sigma_{22}^2 = 20$ and $\sigma_{12} = 16$.
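As an optional illustration in the spirit of Task 1 below, this sketch draws samples from the third, correlated covariance matrix in (6); the number of samples is an arbitrary choice.
# Sketch: samples from a colored (correlated) 2D Gaussian, cf. the third covariance matrix in (6)
import numpy as np
import matplotlib.pyplot as plt
C = np.array([[20.0, 16.0],
              [16.0, 20.0]])   # sigma_11^2 = 20, sigma_22^2 = 20, sigma_12 = 16
samples = np.random.multivariate_normal(np.zeros(2), C, 1000)
plt.scatter(samples[:, 0], samples[:, 1], marker='.', alpha=0.5)
plt.axis("equal")
plt.xlabel("X_1")
plt.ylabel("X_2")
plt.show()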
Task 1 Generate $N = 10^3$ samples from a standard multivariate white normal distribution in dimension $n=2$. Each of these samples is a vector with two entries. Plot these samples on the plane and compare their density to their probability distribution $p(x)$ [cf. (5) with $n=2$, $\bf{\mu}=\bf{0}$ and $\sigma^2=1$]. You should observe that samples accumulate in places where the likelihood $p(x)$ is large.
N = 1000 # Number of samples
mu = np.zeros(2) # mu of sample distribution
C = np.eye(2) # C of sample distribution
samples = np.random.multivariate_normal(mu, C, N) # sample distribution with parameters
# Calculate probability density for visualization
x, y = np.mgrid[-3:3:.01, -3:3:.01]
pos = np.dstack((x, y))
pdf = multivariate_normal(mu, C).pdf(pos)
# Visualize samples vs. probability density
plt.contourf(x, y, pdf, cmap="Oranges")
plt.colorbar(label="Probability Density")
plt.contour(x, y, pdf, cmap="Reds")
plt.scatter(samples[:, 0], samples[:, 1], color = 'teal', marker = '.', label="Samples", alpha = 0.5)
plt.xlabel("X_1")
plt.ylabel("X_2")
plt.legend()
plt.title("Samples vs. Probability Density")
plt.show()
3 Expectations and Statistical Risks¶
An expectation is an average weighted by likelihoods. Suppose that we are working with data $x$ which has a probability distribution $p(x)$. For each data point $x$ we evaluate a function $f(x)$. The expected value of this function over the data distribution is the average $$ \mathbb{E}_{x\sim p}(f(x)) = \int f(x) p(x) \, dx .\tag{7} $$ In this expression, function values $f(x)$ are weighted by the likelihood $p(x)$ of observing $x$. The expectation is an integral over these weighted function values.
The expectation in (7) is to be contrasted with the sample mean. This is an average over a set of data points drawn from the distribution $p(x)$. Formally, consider $N$ points $x_n$ drawn from the bag described by $p(x)$. The sample mean of this set of points is the simple average $$ \bar{f} = \frac{1}{N}\sum_{n=1}^{N} f(x_n) .\tag{8} $$ The most important result in the theory of probability, some would say the only important result, is the proximity between the expectation $\mathbb{E}(f(x))$ and the sample mean $\bar{f}$ when the points $x_n$ are drawn independently. This fact is called the law of large numbers.
The law of large numbers is remarkable because expectations and sample means are two quantities of fundamentally different natures. Whereas the expectation is a property of the probability distribution, the sample mean is a property of the samples. Different sets of data have different sample means. The law of large numbers says that all of these sample means are close to the expectation and, by extension, close to each other.
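A minimal numerical illustration of the law of large numbers, under the arbitrary choices $f(x) = x^2$ and a standard normal distribution, for which the expectation is exactly 1:
# Sketch: sample means of f(x) = x^2 approach the expectation E[x^2] = 1 as N grows
import numpy as np
rng = np.random.default_rng(0)
for N in [10, 100, 1000, 10000]:
    x = rng.standard_normal(N)      # N independent samples from a standard normal
    print(N, np.mean(x ** 2))       # sample mean of f(x) = x^2, converging to 1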
Ultimately, this is the reason why training on empirical risks is justifiable. When we train a model on the empirical risk, $$ w_{\text{ERM}}^* ~=~ \arg\min_w ~\frac{1}{N}\sum_{i=1}^{N}~ c \Big(\, \Phi(x_i, w) \,\Big),\tag{9} $$ the hope is that the solution of empirical risk minimization (ERM) is close to the solution of the statistical risk minimization (SRM) problem $$ w_{\text{SRM}}^* ~=~ \arg\min_w ~\mathbb{E}_{x\sim p} \Big[\, c \Big(\, \Phi(x, w) \,\Big) \, \Big].\tag{10} $$ Again, (9) and (10) are fundamentally different problems. If we think of (9) and (10) as models of digit classification problems, the statistical risk is an expectation (an average) over the bag that contains all of the digits that have, are, or will ever be written. The empirical risk is an average over a set of digits pulled out of this bag. The SRM solution $w_{\text{SRM}}^*$ is a property of the bag. The ERM solution $w_{\text{ERM}}^*$ is a property of the specific set of data samples that are pulled from the bag.
The fundamental result of statistical learning theory is that the solutions of ERM and SRM problems are close. In particular, this result implies that training and execution performance are similar. Both are close to SRM performance. That the SRM and ERM solutions are close is not a ready consequence of the law of large numbers. It is a rather involved consequence which requires restrictions on the choice of the learning parameterization.
It is important to realize that the SRM problem in (10) is the problem that we would like to solve whereas the ERM problem in (9) is the problem that we can solve. Solving SRM requires knowing the probability distribution $p(x)$ which is in general unknown. In Figure 1 we can ascertain with confidence that likelihoods decrease as we move from left to right. However, we cannot ascertain with any confidence the exact likelihood of drawing any of these samples. What we can do is pull samples from the bag of the digits that have, are, or will ever be written. This is ERM.
4 Generative Models¶
Given a probability distribution $q(x)$ we want to learn a probability distribution $p(x, w)$ that imitates $q(x)$. This goal can be subsumed within the definition of artificial intelligence (AI) as the imitation of a natural system. The distribution $q(x)$ is a natural system such as the bag of all digits. We want to learn a distribution $p(x, w)$ that generates digits according to the same likelihood distribution. We call $p(x, w)$ a generative model because, once trained, we can use it to generate (artificial) data that is indistinguishable from the (natural) data that we could sample from $q(x)$ — modulo the accuracy of the trained model.
Designing this generative AI requires the selection of a parameterization and the selection of a loss. We discuss in this section the choice of loss.
4.1 Kullback-Leibler (KL) divergence¶
We compare distributions $q(x)$ and $p(x, w)$ with the Kullback-Leibler (KL) divergence which we define as $$ D_{\text{KL}}\Big(q(x) \| p (x, w)\Big) = \mathbb{E}_{x \sim q } \log \left[ \frac{q(x)}{p(x, w) } \right] = \int q(x) \log \left[ \frac{q(x)}{p(x, w) } \right] dx .\tag{11} $$ The ratio $q(x)/p(x, w)$ is a relative comparison of the similarity between the likelihoods $q(x)$ and $p(x, w)$. This comparison is passed through a logarithm and averaged over the distribution $q(x)$.
To use the KL divergence as a loss we need to make sure that it is indeed a loss. We show in the following proposition that this is true.
Proposition 1
The KL divergence between distributions $q(x)$ and $p(x, w) $ is nonnegative,
$$
D_{\text{KL}}\Big(q(x) \| p (x, w)\Big) \geq 0 \tag{12}
$$
with equality attained when $q(x) = p(x, w) $ for all $x$.
Proof: To show that the KL divergence is nonnegative, rewrite its definition in (11) as $$
D_{\text{KL}}\Big(q(x) \| p (x, w)\Big) = - \int q(x) \log \left[ \frac{p(x, w)}{q(x)} \right] dx . \tag{13}
$$
This equality holds because we reversed the ratio $q(x)/p(x, w)$ inside the logarithm and compensated by adding a negative sign at the front.
To proceed, remember that the logarithm function satisfies $\log(\alpha) \leq \alpha -1$ for any $\alpha>0$. Using this fact in (13) we see that $$ - \int q(x) \log \left[ \frac{p(x, w) }{q(x)} \right] dx \geq ~ - \int q(x) \left[ \frac{p(x, w) }{q(x)} - 1 \right] dx \tag{14} $$ Notice now that in the right hand side we can simplify terms to obtain $$ - \int q(x) \left[ \frac{p(x, w) }{q(x)} - 1 \right] dx = - \int p(x, w) \, dx + \int q(x) \, dx = 0. \tag{15} $$ The second equality is true because $p(x, w)$ and $q(x)$ are probability distributions that integrate to 1 [cf. (1)].
To see that equality holds when $q(x) = p(x, w) $ note that if this is true the ratio $p(x, w) /q(x) = 1$ and the logarithm of this ratio equals zero. Integrating zeroes over all $x$ is still a zero.
Proposition 1 shows that the KL divergence is a valid loss for the problem of imitating the natural distribution $q(x)$ with the distribution $p (x, w)$. The KL divergence is minimized at $D_{\text{KL}}(q(x) \| p (x, w)) = 0$ when $q(x) = p (x, w)$.
The KL divergence is more than a valid loss to compare distributions $q(x)$ and $p (x, w)$. It is a divergence. The meaning of this distinction is not relevant to us, but it is worth pointing out. A notable observation is that KL divergences are not necessarily symmetric, $$ D_{\text{KL}}\Big(q(x) \| p (x, w)\Big) \not \equiv D_{\text{KL}}\Big(p (x, w) \| q(x)\Big) \tag{16} $$ Comparing $p (x, w)$ to $q(x)$ with KL divergences is not the same as comparing $q(x)$ to $p(x,w)$. The choice we make in (17) of comparing $p (x, w)$ to $q(x)$ is not arbitrary. Its motivation will become clear in the next section.
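To make the asymmetry in (16) concrete, here is a small sketch that evaluates the KL divergence between two scalar Gaussians in both directions; it relies on the closed-form expression for the KL divergence between Gaussians, stated here without derivation, and the two example distributions are arbitrary.
# Sketch: KL divergence between two 1D Gaussians, evaluated in both directions
import numpy as np

def kl_gauss(mu1, var1, mu2, var2):
    # Closed-form KL( N(mu1, var1) || N(mu2, var2) ) for scalar Gaussians
    return 0.5 * (np.log(var2 / var1) + (var1 + (mu1 - mu2) ** 2) / var2 - 1)

print(kl_gauss(0.0, 1.0, 1.0, 4.0))   # D_KL(q || p)
print(kl_gauss(1.0, 4.0, 0.0, 1.0))   # D_KL(p || q), generally a different value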
4.2 Risk Minimization for Generative Models¶
Having established that the KL divergence is minimized when $q(x) = p (x, w)$, we formulate the learning of a generative model as the statistical risk minimization problem $$ w_{\text{SRM}}^* ~=~ \arg\min_w D_{\text{KL}}\Big(q(x) \| p (x, w)\Big) ~=~ \arg\min_w \mathbb{E}_{x \sim q } \log \left[ \frac{q(x)}{p(x, w) } \right]. \tag{17} $$ As with any other SRM problem, (17) is a problem that we would like to solve but one that we cannot solve because we do not know the probability distribution $q(x)$. We can, however, resort to the acquisition of data, the formation of a training set, and the formulation of the empirical risk minimization problem, $$ w^* ~=~ \arg\min_w \frac{1}{N} \sum_{n=1}^{N} \log \left[ \frac{q(x_n)}{p(x_n, w)} \right]. \tag{18} $$ This ERM formulation is not workable because we do not know $q(x_n)$. When we pull the samples in Figure 1 from the bag of digits, we do not know how likely they are. If we did know $q(x_n)$, we would be solving the SRM problem in (17).
This issue has a simple solution. Return to the definition of the KL divergence and split the logarithm of the ratio to write $$ \mathbb{E}_{x \sim q} \log \left[ \frac{q(x)}{p(x, w) } \right] = \mathbb{E}_{x \sim q } \Big[\, \log q(x) - \log p(x, w) \,\Big] .\tag{19} $$ In (19) the logarithm of the ratio separates into the logarithm of the natural distribution $\log q(x)$ minus the logarithm of the artificial distribution $\log p(x, w)$. Since the logarithm of the natural distribution $\log q(x)$ is independent of the choice of parameter $w$, its presence in the SRM problem is moot. The value of $\mathbb{E}_{x \sim q }\log q(x)$ is the same irrespective of our choice of the parameter $w$. We have thus shown that the SRM problem in (17) is equivalent to $$ w_{\text{SRM}}^* ~=~ \arg\min_w - \mathbb{E}_{x \sim q } \Big[\, \log p(x, w) \,\Big ] .\tag{20} $$ This is an SRM problem with loss $-\log p(x, w)$. This loss is unusual. It can be negative and it is not directly comparing $p(x, w)$ to $q(x)$. It is just a function of the probability $p(x, w)$. Its use is nevertheless justified because (20) is equivalent to (17) and we have seen that the KL divergence in (17) is a valid loss. The distribution that minimizes (20) is $q(x)$ and this is all that matters in the end.
Since the statistical problems in (17) and (20) are equivalent we can write an empirical version of (20) that will be equivalent to the ERM in (18). This problem takes the form $$ w^* ~=~ \arg\min_w - \frac{1}{N} \sum_{n=1}^{N} \log p(x_n, w). \tag{21} $$ In this ERM problem the distribution $q(x_n)$ is not present. We can solve it by pulling samples from the bag of all digits without having to know the likelihood of the samples. Take a moment to marvel at (21). We are solving a very difficult problem — the imitation of a probability distribution from a set of samples — with the minimization of a very simple objective — a sum of logarithms of the estimated likelihoods of each sample.
5 Generative Normal Distribution Models¶
In (21), the distribution $p(x, w)$ is a family of distributions parameterized by $w$. We choose here to work with white normal distributions with covariance $C = I$ and variable mean $\mu$, $$ p(x, w) = p(x, \mu) = \frac{1}{(2\pi)^{n/2}}~ \exp \left[ -\frac{1}{2} \|x-\mu\|^2 \right] .\tag{22} $$ We are now given a set of points $x_n$ drawn from an unknown distribution $q(x)$ and are asked to find the distribution $ p(x, w) = p(x, \mu)$ within the class defined in (22) that is closest to $q(x)$. This goal can be attained by solving the ERM problem in (21) with $p(x, w) = p(x, \mu)$ as given in (22). After eliminating unnecessary constants this substitution yields the ERM problem $$ \mu^* ~=~ \arg\min_{\mu} \frac{1}{2N} \sum_{n=1}^{N} \|x_n-\mu\|^2 .\tag{23} $$ We say that the probability distribution $p(x, \mu^*)$ is a generative normal distribution model (GNDM). It is a generative model because we can use it to generate data from the same distribution that generated the samples $x_n$. It is a normal distribution model because it fits a normal distribution to the data samples $x_n$.
GNDM training is a very simple optimization problem to solve. Its simplicity stems from the fact that the Gaussian family of distributions in (22) is not very rich. We are using this simple parameterization because it is a preliminary approach to the diffusion models of Section 6.
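In fact, (23) can also be solved in closed form: setting the gradient with respect to $\mu$ to zero yields the sample mean $\mu^* = \frac{1}{N}\sum_n x_n$. The following sketch, with synthetic data mimicking the setting of Task 2 below, checks this closed-form answer; the random seed and sample count are arbitrary.
# Sketch: the minimizer of (23) is the sample mean of the data
import numpy as np
rng = np.random.default_rng(0)
data = rng.multivariate_normal([1.0, 2.0], np.eye(2), 1000)   # synthetic samples, same setting as Task 2
mu_closed_form = data.mean(axis=0)                            # closed-form solution of (23)
print(mu_closed_form)                                         # should be close to [1.0, 2.0]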
Task 2 Let $q(x)$ be a 2D normal distribution with covariance $C=I$ and mean $\mu = [1.0; 2.0]$. Generate $m = 10^3$ samples $x_n$ of the distribution $q(x)$. Solve (23) for this dataset. Compare the learned distribution $p(x, \mu^*)$ with the dataset. They should match closely.
N = 2000 # Number of samples
mu_true = np.array([1.0, 2.0]) # mu of sample distribution
C = np.eye(2) # C of sample dsitrbution
original_samples = torch.tensor(np.random.multivariate_normal(mu_true, C, N)) # Sample distribution with defined parameters
mu = torch.randn(2, requires_grad=True) # Initialize mu randomly for learning
optimizer = optim.SGD([mu], lr=0.1)
epochs = 500
for epoch in range(epochs):
optimizer.zero_grad()
loss = torch.mean((original_samples - mu) ** 2) # Mean squared error
loss.backward()
optimizer.step()
mu_star = mu.detach().numpy()
print("Learned mu_star =", mu_star)
print("Original mu =", mu_true)
learned_samples = np.random.multivariate_normal(mu_star, C, N) # Generate samples with learned mu
# Calculate probability density for visualization
x, y = np.mgrid[-3:6:.01, -3:6:.01]
pos = np.dstack((x, y))
pdf = multivariate_normal(mu_true, C).pdf(pos)
# Visualize samples generated from learned distribution (mu) vs. original probability density
plt.contourf(x, y, pdf, cmap="Oranges")
plt.colorbar(label="Probability Density")
plt.contour(x, y, pdf, cmap="Reds")
plt.scatter(learned_samples[:, 0], learned_samples[:, 1], label="Learned samples", marker = '.', color = 'teal', alpha = 0.2)
plt.xlabel("X_1")
plt.ylabel("X_2")
plt.legend()
plt.title("Samples generated from Learned Distribution vs. Original Probability Density")
plt.show()
Learned mu_star = [1.0628139 1.9944205] Original mu = [1. 2.]
6 Generative Diffusion Models¶
In Task 2 we succeeded in learning a Gaussian probability distribution $p(x, \mu^*)$ that matches samples $x_n$ of an unknown Gaussian probability distribution $q(x)$. In practice, we want to learn distributions $q(x)$ that are more complex than Gaussian distributions. Say, the probability distribution of all the digits that have, are, or will ever be written. To learn these distributions we need a complex parametric family of distributions $p(x,w)$ to match to our data. Generative diffusion models (GDMs) are such a class. They are based on the parameterization of a backward diffusion process (Section 6.2) whose introduction requires that we first explain forward diffusion processes (Section 6.1).
6.1 Forward Diffusion Process¶
A forward diffusion process is defined by the recursive mixing of data with samples from a standard white normal distribution. Consider then data samples $x = x_0 \in \mathbb{R}^n$ drawn from a data distribution $q(x_0) = q(x)$,
$$
x = x_0 \sim q(x_0) = q(x) . \tag{24}
$$
Introduce now a time index $t=1,\ldots,T$ and associate samples $\epsilon_t \in \mathbb{R}^n$ from a standard white normal distribution $\mathcal{N}(\epsilon_t;\, \bf{0}, \textbf{I})$ with each time index $t$ — the standard white distribution is given by (5) with $\mu=\bf{0}$ and $\sigma^2=1$. Further consider a sequence of scalar coefficients $\alpha_t<1$ and define the sequence
$$
x_t
= \Big(\sqrt{\alpha_t}\Big) \times x_{t-1}
+ \Big(\sqrt{1 - \alpha_t}\Big) \times \epsilon_t . \tag{25}
$$
When adding the sample $\epsilon_t$ to $x_{t-1}$ we say that we are adding noise. This is because, as we can see in Figure 6, samples from white normal distributions look like noise. Although not required, we choose constants $\alpha_t\approx 1$ in practice. This means that at each step we are adding a small amount of noise.
The idea of the recursion in (25) is to add noise progressively. In the first application of (25) we go from the input data sample $x_0$, which does not have any noise added, to sample $x_1$. This sample has some amount of noise. We then proceed to add some more noise to create $x_2$ and even more noise to create $x_3$. Observe that since $\alpha_t<1$ the noise is becoming more prominent while the input data $x_0$ is being washed out. The goal is that when we get to time $T$, the signal $x_T$ is almost the same as a sample from a white normal distribution. That we end up with pure noise is key to construct the backward process in Section 6.2 and its learned version in Section 6.3.
It is easy to miss that (25) not only specifies how the data sample $x_0$ generates the sequence of samples $x_t$ but also how the data distribution $q(x_0)$ generates a sequence of distributions $q(x_t)$. The latter describes the likelihood of observing $x_t$ at time $t$ and depends on a combination of the likelihoods $q(x_0)$ of drawing different data samples $x_0$ and the likelihoods $\mathcal{N}(\epsilon_u;\, \bf{0}, \textbf{I})$ of drawing different noise samples $\epsilon_u$ at times $u \leq t$.
The distribution $q(x_t)$ is complex. An easier distribution to characterize is the distribution at time $t$ assuming that we know the value of the diffusion process at time $t-1$. We denote this distribution as $q(x_t|x_{t-1})$ and call it the conditional distribution of $x_t$ given $x_{t-1}$. If we know $x_{t-1}$ in (25) the only random quantity is the noise $\epsilon_t$. Thus, this conditional distribution is normal with mean $\mu = (\sqrt{\alpha_t}) \times x_{t-1}$ and covariance matrix $C = (1-\alpha_t)I$, $$ q(x_t|x_{t-1}) = \mathcal{N}\Big(\, x_t; \, (\sqrt{\alpha_t})\times x_{t-1}, \, (1 - \alpha_t)\textbf{I} \,\Big) . \tag{26} $$ Equations (25) and (26) are not quite the same. Equation (25) describes one path of the diffusion process. Equation (26) describes the evolution of the probability distribution of the diffusion process. Since we can go from (25) to (26) and vice versa, this discussion can be a little pedantic. Still, it is important to understand that the diffusion process [cf. (26)] and a path of the diffusion process [cf. (25)] are different. Like a handwritten digit is different from the probability distribution of handwritten digits.
It is relevant to forthcoming derivations that the sum of Gaussian random variables is also Gaussian. In particular, it is known that the recursion in (25) can be unrolled between times $0$ and $t$ to write $$ x_t = \Big(\sqrt{\bar\alpha_t}\Big) \times x_{0} + \Big(\sqrt{1 - \bar\alpha_t}\Big) \times \bar\epsilon_t , \tag{27} $$ where the constant $\bar{\alpha}_t = \alpha_1\times\ldots\times\alpha_t$ is the product of the $\alpha_t$ constants in (25) and $\bar\epsilon_t \in \mathbb{R}^n$ is sampled from a standard white normal distribution $\mathcal{N}(\bar\epsilon_t;\, \bf{0}, \textbf{I})$. Equivalently, the expression in (26) can also be unrolled between times $0$ and $t$ to write $$ q(x_t|x_{0}) = \mathcal{N}\Big(\, x_t; \, (\sqrt{\bar\alpha_t}) \times x_{0}, \, (1 - \bar\alpha_t)\textbf{I} \,\Big). \tag{28} $$ The expressions in (27) and (28) let us sample a realization of $x_t$ directly from $x_0$ without the burden of computing all of the intermediate steps $x_u$ for $u<t$.
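A quick numerical check of (27), a sketch under an arbitrary linear schedule: starting from a fixed $x_0$, the sample obtained by running the recursion (25) step by step should have the mean and variance predicted by the one-shot expressions (27) and (28).
# Sketch: compare the step-by-step recursion (25) with the one-shot formula (27)/(28)
import numpy as np
rng = np.random.default_rng(0)
T = 100
alphas = np.linspace(0.9999, 0.98, T)   # arbitrary linear schedule with alpha_t < 1
alpha_bar = np.cumprod(alphas)
x0 = 1.5                                # a fixed scalar "data sample"
trials = 100000
x = np.full(trials, x0)                 # run the recursion (25) for all trials in parallel
for t in range(T):
    x = np.sqrt(alphas[t]) * x + np.sqrt(1 - alphas[t]) * rng.standard_normal(trials)
print(x.mean(), x.var())                               # empirical mean and variance of x_T
print(np.sqrt(alpha_bar[-1]) * x0, 1 - alpha_bar[-1])  # mean and variance predicted by (27)/(28)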
6.2 Backward Diffusion Process¶
Now comes the hard part: How can we start from a pure noise sample $x_T$ and iteratively denoise the image until we recover an original sample $x_0$? Essentially, our goal is to reverse the Forward process. We call this reversed process the backward process.
Recall that each step in the Forward process is essentially sampling from a Gaussian distribution. Luckily, it turns out that if we knew the original sample $x_0$ that we started from, then each denoising step in the backward process would also be just sampling from a Gaussian, $$ q(x_{t-1} \vert x_t, x_0) = \mathcal{N}\Big(x_{t-1}; \mu_q (x_t, x_0), \sigma_q^2(t)\textbf{I}\Big) \tag{29} $$ This is due to properties of the Gaussian distribution. The variance of this Gaussian only depends on our noise schedule and is given by: \begin{equation} \sigma_q^2(t) = \frac{(1 - \alpha_t)(1 - \bar \alpha_{t-1})}{1 - \bar \alpha_t} \tag{30} \end{equation} The mean of this Gaussian, $\mu_q (x_t, x_0)$, is essentially the likeliest $x_{t-1}$ from which $x_t$ could have been generated and, importantly, it depends on $x_0$ as well as $x_t$.
\begin{equation} \mu_q(x_t, x_0) = \frac{\sqrt{\alpha_t}(1 - \bar{\alpha}_{t-1})x_t + \sqrt{\bar{\alpha}_{t-1}}(1 - \alpha_t)x_0}{1 - \bar{\alpha}_t} \tag{31} \end{equation}
In practice, it is usually easier to work with the added noise $\epsilon_0$ rather than the original image $x_0$. Thus from (27) we can write $x_0$ in terms of $\epsilon_0$ and $x_t$ to get: $$ \mu_q(x_t, \epsilon_0) = \frac{1}{\sqrt{\alpha_t}} x_t - \frac{1 - \alpha_t}{\sqrt{1 - \bar{\alpha}_t} \sqrt{\alpha_t}}\epsilon_0 \tag{32} $$ Recall that our goal in generative modeling was to learn a distribution $p(x, w)$ that imitates $q(x)$. In diffusion models, we have established so far that to get a sample from $q(x)$, we can start from a pure noise sample $x_T$, and sample from conditional Gaussian distributions $q (x_{t-1}| x_t, x_0) = \mathcal{N}\Big(x_{t-1}; \mu_q (x_t, \epsilon_0), \sigma_q^2(t)\textbf{I}\Big)$ at each time step from $t = T, T-1, \cdots , 1$, until we reach $x_0$. Let’s call this the “true backward process”. To imitate this process, we define our model $p(x, w)$ in the same way, i.e., to sample from $p(x, w)$, we start from a pure noise sample $x_T$, and sample from conditional Gaussian distributions $p(x_{t-1} | x_t, w)$ at each time step until we reach $x_0$. We call this the learned backward process and discuss it in the next section.
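For concreteness, the sketch below evaluates the posterior variance (30) and mean (32) for a single time step, under an arbitrary linear noise schedule and toy tensors; it illustrates the formulas only, not the full backward process.
# Sketch: evaluate sigma_q^2(t) from (30) and mu_q(x_t, eps_0) from (32) for one time step
import torch
T = 500
alphas = torch.linspace(0.9999, 0.98, T)   # arbitrary noise schedule
alpha_bar = torch.cumprod(alphas, dim=0)
t = 250                                    # an arbitrary intermediate time step
x_t = torch.randn(4)                       # toy "noisy sample" (4-dimensional for illustration)
eps_0 = torch.randn(4)                     # toy noise sample
sigma2_q = (1 - alphas[t]) * (1 - alpha_bar[t - 1]) / (1 - alpha_bar[t])   # equation (30)
mu_q = (x_t / torch.sqrt(alphas[t])
        - (1 - alphas[t]) / (torch.sqrt(1 - alpha_bar[t]) * torch.sqrt(alphas[t])) * eps_0)   # equation (32)
print(sigma2_q, mu_q)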
6.3 Learned Backward Diffusion Process¶
Our goal is for the distribution of samples generated by the learned backward process, $p(x_0, w)$, to match the distribution of samples generated by the true backward process, $q(x_0)$, which is in fact the underlying data distribution. The bad news is that unlike the generative normal distribution model of Section 5, where we could evaluate $p(x, w)$ for any given $x, w$, for a diffusion model evaluating $p(x_0, w)$ is intractable (both in theory and in practice). However, the good news is that if the intermediate denoising steps of the true and learned backward processes match closely, meaning if the Gaussian distributions $q (x_{t-1}| x_0, x_t)$ and $p(x_{t-1} | x_t, w)$ are close for all $t = 1, \cdots, T$, then the final sample distributions $q(x_0)$ and $p(x_0, w)$ will also be close. Therefore we choose our loss function in a way that ensures these distributions are close at every step. Thus, in a similar manner to (17) but over all time steps, we write: $$ w_{\text{SRM}}^* ~=~ \arg\min_w \mathbb{E}_{x_0} \left[ \sum_{t = 2}^T \mathbb{E}_{x_t \sim q (x_t | x_0)} \left[D_{\text{KL}}\Big(q (x_{t-1}| x_0, x_t) \| p(x_{t-1} | x_t, w)\Big) \right] \right] \tag{33} $$
We will now focus on simplifying the loss in (33). Recall that $q (x_{t-1}| x_0, x_t)$ is a Gaussian distribution with mean $\mu_q(x_t, x_0)$ and variance $\sigma_q^2(t)$. To imitate it as closely as possible, it makes sense to choose $p(x_{t-1} | x_t, w)$ to also be a Gaussian with similar mean and variance:
\begin{equation} q(x_{t-1} \vert x_t, x_0) = \mathcal{N}\Big(x_{t-1}; \mu_q (x_t, \epsilon_0), \sigma_q^2(t)\textbf{I}\Big) \tag{34} \end{equation}
\begin{equation} p(x_{t-1} | x_t, w) = \mathcal{N}\Big(x_{t-1}; \mu_p (x_t, w), \sigma_q^2(t)\textbf{I}\Big) \tag{35} \end{equation}
Note that the variance has been chosen to be exactly the same. We can do this because $\sigma_q^2(t)$ only depends on the time step and a fixed noise schedule, both of which we have access to in the learned backward process. With this smart choice of $p(x_{t-1} | x_t, w)$, we can reduce the loss in (33) to something more familiar:
\begin{equation}
w_{\text{SRM}}^* ~=~ \arg\min_w \mathbb{E}_{x_0} \left[ \sum_{t = 2}^T \mathbb{E}_{x_t \sim q (x_t | x_0)} \left[ \|\ \mu_p(x_t, w)- \mu_q(x_t, \epsilon_0)\|^2 \right] \right] \tag{36}
\end{equation}
This should remind us of the loss in (23). Our goal has now been simplified to matching $\mu_q(x_t, \epsilon_0)$ as closely as possible without knowing $\epsilon_0$. Again we can choose our model smartly based on our knowledge of $\mu_q(x_t, \epsilon_0)$.
\begin{equation}
\begin{array}{rcl}
\displaystyle \mu_q(x_t, \epsilon_0) & = & \displaystyle \frac{1}{\sqrt{\alpha_t}} x_t - \frac{1 - \alpha_t}{\sqrt{1 - \bar{\alpha}_t} \sqrt{\alpha_t}} \cdot \epsilon_0 \\[2ex]
\displaystyle \mu_p(x_t, w) & = & \displaystyle \frac{1}{\sqrt{\alpha_t}} x_t - \frac{1 - \alpha_t}{\sqrt{1 - \bar{\alpha}_t} \sqrt{\alpha_t}} \cdot \hat\epsilon (x_t, t, w)
\end{array} \tag{37}
\end{equation}
With this choice of $\mu_p(x_t, w)$ the loss in (36) turns into:
\begin{equation}
w_{\text{SRM}}^* ~\simeq~ \arg\min_w \mathbb{E}_{x_0} \left[ \sum_{t = 2}^T \mathbb{E}_{x_t \sim q (x_t | x_0)} \left[ \|\ \hat\epsilon (x_t, t, w) - \epsilon_0\|^2 \right] \right] \tag{38}
\end{equation}
We have ignored the time dependent weighting terms $\frac{1 - \alpha_t}{\sqrt{1 - \bar{\alpha}_t} \sqrt{\alpha_t}}$ as including them doesn’t result in any performance gains. At the end of the day, to train our diffusion model, we have to train a model $\hat\epsilon (x_t, t, w)$ to predict the true noise $\epsilon_0$, given the current noisy sample $x_t$ and time step $t$. The ERM version of the SRM problem in (38) is:
\begin{equation}
w_{\text{ERM}}^* ~=~ \arg\min_w \sum_{i = 1}^N \sum_{t = 2}^T \sum_{j = 1}^M \|\ \hat\epsilon (x_t^{(i, j)}, t, w) - \epsilon_0^{(i, j)}\|^2 \tag{39}
\end{equation}
where $\epsilon_0^{(i, j)} \sim \mathcal{N}(\epsilon_0^{(i, j)}; \textbf{0},\textbf{I})$, and $x_t^{(i, j)} = \sqrt{\bar{\alpha}_t} \cdot x_0^{(i)} + \sqrt{1 - \bar{\alpha}_t} \cdot \epsilon_0^{(i, j)}$.
In practice during training, for a batch of samples, e.g. $x_0^{(1)}, \cdots, x_0^{(B)}$, we randomly sample $B$ time steps $t^{(1)}, \cdots, t^{(B)}$ uniformly from $2$ to $T$, and also sample $B$ points $\epsilon_0^{(1)}, \cdots, \epsilon_0^{(B)}$ from a standard Gaussian. The batch loss would then be: \begin{equation} \sum_{i = 1}^B \|\ \hat\epsilon (x_t^{(i)}, t^{(i)}, w) - \epsilon_0^{(i)}\|^2 \tag{40} \end{equation}
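The training loop in Task 6 below implements exactly this; for orientation, here is a condensed sketch of one gradient step on the batch loss (40), where noise_predictor is a stand-in placeholder for the noise model $\hat\epsilon(x_t, t, w)$ and all dimensions and schedule values are arbitrary.
# Sketch: one gradient step on the batch loss (40); noise_predictor is a placeholder model
import torch
T = 500
alphas = torch.linspace(0.9999, 0.98, T)
alpha_bar = torch.cumprod(alphas, dim=0)
x0 = torch.randn(8, 4)                    # toy batch of B = 8 "images" (4-dimensional here)
t = torch.randint(2, T, (8,))             # one random time step per sample
eps = torch.randn_like(x0)                # standard Gaussian noise, one sample per data point
x_t = torch.sqrt(alpha_bar[t]).unsqueeze(1) * x0 + torch.sqrt(1 - alpha_bar[t]).unsqueeze(1) * eps
noise_predictor = torch.nn.Linear(5, 4)   # placeholder for the U-Net, taking [x_t, t] as input
pred = noise_predictor(torch.cat([x_t, t.float().unsqueeze(1) / T], dim=1))
loss = ((pred - eps) ** 2).sum()          # batch loss (40)
loss.backward()                           # gradients for a subsequent optimizer step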
Depending on the type of data that we are generating, we could choose different models as our noise predictor $\hat\epsilon (x_t, t, w)$. For generating images we will use a U-Net architecture as our noise predictor as discussed in the next section.
7 Image Generative Diffusion Models¶
In the remainder of this lab, we will train a diffusion model to learn to generate images of handwritten digits by training it on the MNIST dataset. We start by loading the dataset.
Task 3 Load the MNIST dataset. Since training on the entire dataset will take very long, we will be training on a subset of it with $m = 3000$ samples. Your training set should have a roughly equal number of samples from each digit class. You should also normalize each image and resize it to be $32 \times 32$ pixels.
# Get access to your Google Drive
from google.colab import drive
drive.mount('/content/drive')
# Download the MNIST dataset
#transform resizes and normalizes the images in our dataset
transform = transforms.Compose([
transforms.Resize((32,32)),
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5])
])
batch_size = 64
train_data = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
test_data = datasets.MNIST(root='./data', train=False, transform=transform, download=True)
# Adjust subset size here, so you can reduce the training time
subset_size = int(0.05 * len(train_data))
indices = torch.randperm(len(train_data))[:subset_size]
train = Subset(train_data, indices)
subset_size = int(0.1 * len(test_data))
indices = torch.randperm(len(test_data))[:subset_size]
test = Subset(test_data, indices)
train_loader = torch.utils.data.DataLoader(train, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(test, batch_size=batch_size, shuffle=True)
# Save samples in a folder named "MNIST_samples" in your Google Drive
mnist_save_dir = '/content/drive/MyDrive/MNIST_samples'
os.makedirs(mnist_save_dir, exist_ok=True)
num_samples = 10000 # Number of samples to save
for i in range(num_samples):
    image, _ = train_data[i]
    save_image(image, os.path.join(mnist_save_dir, f"sample_train_{i}.png"))
# Plot a few samples
num_samples = 2
sample_train = list(enumerate(train_loader))[0][1][0][:num_samples]
sample_test = list(enumerate(test_loader))[0][1][0][:num_samples]
fig, axs = plt.subplots(2, num_samples)
for i, sample in enumerate(sample_train):
    axs[0, i].imshow(sample.numpy().squeeze(), cmap='gray')
for i, sample in enumerate(sample_test):
    axs[1, i].imshow(sample.numpy().squeeze(), cmap='gray')
Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
We now write a function to implement the Forward process.
Task 4 Write a function that takes a batch of images, and a list of timesteps, as inputs. The output is a batch of noisy versions of the input images according to the given timesteps. (See Figure 6). For the noise schedule, $\alpha_t$ decreases linearly across timesteps from $\alpha_0 = 0.9999$ to $\alpha_T = 0.98$. Choose the total number of diffusion time steps as $T = 500$. Show an example batch with a single image. Plot the noisy images at time steps $t = 100,\ 200,\ 300,\ 400,\ 500$. For $t = 500$ the image should be indistinguishable from pure noise.
start_alpha_t = 0.9999
end_alpha_t = 0.98
max_timestep = 1000
decrease_rate = (end_alpha_t - start_alpha_t) / max_timestep # The rate that alpha decreases over t, decrease_rate is negative
alpha_t_list = torch.arange(start_alpha_t, end_alpha_t, decrease_rate).to(device) # alpha_t+1 = alpha_t + decrease_rate
alpha_hat_t_list = torch.cumprod(alpha_t_list, dim=0).to(device) # Cumulative product, eg. alpha_hat_t = alpha_hat_t-1 * alpha_hat_t-2 * ... * alpha_hat_0
# Define a function to generate a sequence of noised images.
# Inputs: X = a batch of images
#         times = a list of timesteps
# Outputs: the noised images and the noise that was added (needed later as the training target)
def generate_noise_progression(X, times):
    noise = torch.randn(X.shape).to(device)                   # Standard Gaussian noise, one sample per image
    assert times.shape[0] == X.shape[0]
    alpha_hat_t = alpha_hat_t_list[times].view(-1, 1, 1, 1)   # alpha_bar_t for each requested timestep
    noised_X = torch.sqrt(alpha_hat_t) * X + noise * torch.sqrt(1 - alpha_hat_t)   # Equation (27)
    return noised_X, noise
times = [99, 199, 299, 399, 499]
noised_sample, _ = generate_noise_progression(sample_train[0].unsqueeze(0).repeat(len(times), 1, 1, 1).to(device), torch.tensor(times).to(device))
fig, axs = plt.subplots(1, len(times))
for i, sample in enumerate(noised_sample):
    axs[i].imshow(sample[0].detach().cpu().numpy(), cmap='gray')
It remains to choose the noise predictor model $\hat\epsilon (x_t, t, w)$. We choose a time-conditional U-Net model for this purpose. A U-Net is a convolutional neural network architecture initially used mainly for image segmentation tasks. It has also proven useful for denoising images, which is why we use it here to predict the noise added to images. The U-Net features an encoder-decoder structure. The encoder, or contracting path, captures context through convolutional and pooling layers that down-sample the input image. The decoder, or expansive path, reconstructs the image using up-sampling layers, which are combined with high-resolution features from the encoder via skip connections. These skip connections help preserve spatial information. The architecture is symmetrical, resembling a U shape, hence the name. For the purposes of this lab, just note that the U-Net consists of convolutional layers, which we know are a good fit for processing images. The U-Net model that you will be using has been provided in the notebook.
We also need to implement the backward process, which we can later use to sample new images with our trained model.
Task 5 Write a sampling function that takes a trained U-Net and a number $m$ as inputs and generates $m$ samples by running the backward process described in (35).
# Define a function to generate images.
# Inputs: model = a trained U-Net model
# num_images = number of images to generate
def generate_images(estimator, num_images, timesteps=max_timestep):
    with torch.no_grad():
        x_t = torch.randn([num_images, 1, 32, 32]).to(device)   # Start from a pure noise sample x_T
        for t in trange(timesteps - 1, 0, -1):
            timestep = torch.tensor([t]).repeat(num_images).to(device).int()
            noise_pred = estimator(x_t, timestep).sample   # Estimator's noise prediction
            alpha_t = alpha_t_list[timestep].view(-1, 1, 1, 1)
            alpha_bar_t = alpha_hat_t_list[timestep].view(-1, 1, 1, 1)
            sqrt_1_alpha_bar_t = torch.sqrt(1 - alpha_bar_t)
            sqrt_alpha_t = torch.sqrt(alpha_t).to(device)
            noise = torch.randn([num_images, 1, 32, 32]).to(device)   # Gaussian noise used to generate new samples
            x_t = (1 / sqrt_alpha_t) * (x_t - ((1 - alpha_t) / sqrt_1_alpha_bar_t) * noise_pred) + torch.sqrt(1 - alpha_t) * noise   # Equations (35) and (37)
        x_t = (x_t / 2 + 0.5).clamp(0, 1)   # Scale x_t to the correct range for visualization
        return x_t.detach().cpu().squeeze().numpy()
Finally, we train our noise predictor model using gradient descent.
Task 6 Write a PyTorch training loop implementing gradient descent with the batch loss given in (40). Train the model for at least 200 epochs. Plot the training loss as a function of the epoch number. After training, generate 64 images using your trained model and the sampling function you wrote for Task 5. Display them in an 8 by 8 grid.
A Few Notes For Implementation.
Training for 200 epochs on m = 3000 samples might take a long time depending on the system you are running on. To make things easier for yourself, you can write your code and make sure your training process works on a smaller number of samples first (e.g. 100 samples); you will likely need fewer epochs as well. Do all of your debugging at this stage. Once you are sure that everything works with a small training set, you can do the training only once on the large dataset with m = 3000 samples and E = 200 epochs.
Also be sure to save your trained model somewhere, so that if something happens to the notebook or you get disconnected from Colab you won’t have to retrain a model from scratch, which would take a long time.
# Define a UNet model
in_channels = 1
out_channels = 1
act_fn = "silu"
block_out_channels = (128, 128, 256, 256, 512, 512)
down_block_types = (
"DownBlock2D",
"DownBlock2D",
"DownBlock2D",
"DownBlock2D",
"AttnDownBlock2D",
"DownBlock2D",
)
up_block_types = (
"UpBlock2D",
"AttnUpBlock2D",
"UpBlock2D",
"UpBlock2D",
"UpBlock2D",
"UpBlock2D"
)
estimator = UNet2DModel(sample_size=32, # Size of images (32 x 32)
in_channels=in_channels,
out_channels=out_channels,
act_fn=act_fn,
block_out_channels=block_out_channels,
down_block_types=down_block_types,
up_block_types=up_block_types
).to(device)
train_loss_evolution = []
optimizer = optim.Adam(estimator.parameters(), lr=1e-4)
num_epochs = 100
loss = torch.nn.MSELoss()
estimator.train()
for epoch in trange(num_epochs):
    train_loss = 0
    for i, (x, _) in enumerate(train_loader):
        # (Step i) Load the data and send data to GPU memory
        x = x.to(device)
        # (Step ii) Compute the gradients. We use automated differentiation.
        optimizer.zero_grad()   # Gradient reset to indicate where the backward computation stops.
        times = torch.randint(2, max_timestep, size=(x.shape[0],)).to(device)   # Sample a random timestep per image from [2, max_timestep).
        noised_x, noise = generate_noise_progression(x, times)   # Noised images and the noise that was added.
        noise_pred = estimator(noised_x, times).sample   # Call the diffusion model to predict the noise.
        mse = loss(noise_pred, noise)   # Compare the predicted noise to the true noise, cf. (40).
        mse.backward()   # Compute gradients moving backwards until the gradient reset.
        # (Step iii) Update parameters by taking an SGD (or other optimizer) step.
        optimizer.step()
        train_loss += mse.item()
    average_train_loss = train_loss / len(train_loader)
    train_loss_evolution.append(average_train_loss)
100%|██████████| 100/100 [30:29<00:00, 18.30s/it]
# Plot training loss evolution
plt.plot(train_loss_evolution)
plt.xlabel('Epoch')
plt.ylabel('Train Loss')
plt.title('Training Loss Evolution')
plt.show()
print(f'Train Loss: {train_loss_evolution[-1]}') # Print training loss.
Train Loss: 0.015511260863314284
# Save trained estimator to Google Drive
estimator_save_path = '/content/drive/MyDrive/Diffusion_Model.pth'
torch.save(estimator.state_dict(), estimator_save_path)
print(f"Estimator saved to {estimator_save_path}")
# Generate 1000 images with trained diffusion model
generated_images = generate_images(estimator, 1000)
# Save generated samples to a folder named "Generated_samples" on your Google Drive
generate_save_dir = '/content/drive/MyDrive/Generated_samples'
os.makedirs(generate_save_dir, exist_ok=True)
for i, image in enumerate(torch.tensor(generated_images)):
    save_image(image, os.path.join(generate_save_dir, f"generated_image_{i}.png"))
100%|██████████| 999/999 [27:29<00:00, 1.65s/it]
# Visualize the output
fig, axes = plt.subplots(nrows=8, ncols=8, figsize=(25, 25))
for idx, ax in enumerate(axes.flatten()):
    ax.imshow(generated_images[idx], cmap='gray')
    ax.axis('off')
    ax.set_title(f'Generated image {idx + 1}')
plt.show()
8 Evaluating a Diffusion Model¶
Now that we have trained a diffusion model, it is time to evaluate it quantitatively. In order to do this we use a metric called the Fréchet Inception Distance (FID) score. The FID score essentially measures the distance between the mean and covariance of samples generated by the generative model and the ‘ground truth’ mean and covariance computed from the data. A smaller FID score (and therefore a smaller distance) means the model has learned to capture the underlying data distribution more accurately. You can think of the FID as something akin to a test error. It is important to note that the distance is not computed in the pixel space but rather in a feature space, using features extracted by an Inception neural network.
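For intuition, the Fréchet distance between two Gaussians with means $\mu_1, \mu_2$ and covariances $\Sigma_1, \Sigma_2$ is $\|\mu_1 - \mu_2\|^2 + \text{Tr}\big(\Sigma_1 + \Sigma_2 - 2(\Sigma_1 \Sigma_2)^{1/2}\big)$; the cleanfid call below computes this quantity on Inception features of the two image sets. Here is a minimal sketch of the formula itself on made-up toy statistics.
# Sketch: Frechet distance between two Gaussians, the quantity FID computes on Inception features
import numpy as np
from scipy.linalg import sqrtm

def frechet_distance(mu1, sigma1, mu2, sigma2):
    # ||mu1 - mu2||^2 + Tr(sigma1 + sigma2 - 2 (sigma1 sigma2)^{1/2})
    covmean = sqrtm(sigma1 @ sigma2).real   # discard tiny imaginary parts from numerical error
    return np.sum((mu1 - mu2) ** 2) + np.trace(sigma1 + sigma2 - 2 * covmean)

mu1, sigma1 = np.zeros(2), np.eye(2)                  # toy "data" statistics
mu2, sigma2 = np.array([1.0, 0.0]), 2 * np.eye(2)     # toy "model" statistics
print(frechet_distance(mu1, sigma1, mu2, sigma2))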
Task 7 Compute the FID score of your trained model with respect to a large subset of the MNIST dataset (e.g. 10000 samples) using the cleanFID library. The more generated samples you use to compute the FID, the more accurate the score will be. However, generating samples using your trained model can be slow. Try using at least a few thousand generated samples.
# Calculate the FID scores using downloaded MNIST samples and generated samples
score = fid.compute_fid(mnist_save_dir, generate_save_dir)
print(f'FID score: {score}')
compute FID between two folders
/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py:617: UserWarning: This DataLoader will create 12 worker processes in total. Our suggested max number of worker in current system is 2, which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary. warnings.warn(
Found 10000 images in the folder /content/drive/MyDrive/MNIST_samples
FID MNIST_samples : 100%|██████████| 313/313 [04:21<00:00, 1.20it/s]
Found 1000 images in the folder /content/drive/MyDrive/Generated_samples
FID Generated_samples : 100%|██████████| 32/32 [00:28<00:00, 1.12it/s]
FID score: 86.39788071634948