# Intro to Diffusion Model — Part 4

## In this post, we take a closer look at the forward process. We define the noise schedule and write some code to implement this process.

This post is part of a series of posts about diffusion models:

- Intro to Diffusion Model — Part 1
- Intro to Diffusion Model — Part 2
- Intro to Diffusion Model — Part 3
- Intro to Diffusion Model — Part 4 (this post)
- Intro to Diffusion Model — Part 5
- Intro to Diffusion Model — Part 6
- Intro to Diffusion Model — Part 7
- Full implementation

In part 2, we defined the forward process. Recall that in this process, we gradually add noise to the original image over a series of time steps. At each time step *t*, the variance of the noise added relative to the previous step is predetermined by the variance schedule and is denoted βₜ, where 0 < β₁ < … < β_T < 1.

The variance schedule can take different forms, such as linear, cosine, quadratic, etc. In the paper Denoising Diffusion Probabilistic Models, the authors used a linear schedule, meaning β grows linearly from some initial value to some final value (in the paper, they used β₁ = 10⁻⁴ and β_T = 0.02).

In a later paper, Improved Denoising Diffusion Probabilistic Models, it was shown that using a cosine schedule provides better results than the linear one.

Let’s implement these two schedules.

```python
import torch

def linear_schedule(num_timesteps):
    beta_start = 1e-4
    beta_end = 0.02
    betas = torch.linspace(beta_start, beta_end, num_timesteps)
    return betas

def cosine_schedule(num_timesteps, s=0.008):
    def f(t):
        return torch.cos((t / num_timesteps + s) / (1 + s) * 0.5 * torch.pi) ** 2
    x = torch.linspace(0, num_timesteps, num_timesteps + 1)
    alphas_cumprod = f(x) / f(torch.tensor([0]))
    betas = 1 - alphas_cumprod[1:] / alphas_cumprod[:-1]
    betas = torch.clip(betas, 0.0001, 0.999)
    return betas
```

Now, we want to define a function for sampling from the forward process. First, let’s define a helper that gathers values from a time-indexed tensor according to a batch of timesteps, and reshapes the result so it broadcasts against a batch of images.

```python
def sample_by_t(tensor_to_sample, timesteps, x_shape):
    batch_size = timesteps.shape[0]
    sampled_tensor = tensor_to_sample.gather(-1, timesteps.cpu())
    sampled_tensor = torch.reshape(sampled_tensor, (batch_size,) + (1,) * (len(x_shape) - 1))
    return sampled_tensor.to(timesteps.device)
```
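To see what this helper does, here is a small standalone example (repeating the definition so the snippet runs on its own). Given a batch of two images at timesteps 0 and 3, it picks the matching entries of a per-timestep tensor and reshapes them for broadcasting:

```python
import torch

def sample_by_t(tensor_to_sample, timesteps, x_shape):
    batch_size = timesteps.shape[0]
    sampled = tensor_to_sample.gather(-1, timesteps.cpu())
    sampled = torch.reshape(sampled, (batch_size,) + (1,) * (len(x_shape) - 1))
    return sampled.to(timesteps.device)

vals = torch.tensor([0.1, 0.2, 0.3, 0.4])  # one value per timestep
t = torch.tensor([0, 3])                   # two images, at timesteps 0 and 3
out = sample_by_t(vals, t, (2, 3, 128, 128))
print(out.shape)      # torch.Size([2, 1, 1, 1])
print(out.flatten())  # tensor([0.1000, 0.4000])
```

The trailing singleton dimensions are what let `out * x0` broadcast over a batch of CHW images.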

and define αₜ, α̅ₜ, βₜ, and some operations on them, as needed according to the derivations in part 2 and part 3.

```python
num_timesteps = 300

betas_t = linear_schedule(num_timesteps)
alphas_t = 1. - betas_t
alphas_bar_t = torch.cumprod(alphas_t, dim=0)
# α̅₀ = 1 by convention, so prepend 1 for the shifted sequence
alphas_bar_t_minus_1 = torch.cat((torch.tensor([1.]), alphas_bar_t[:-1]))
one_over_sqrt_alphas_t = 1. / torch.sqrt(alphas_t)
sqrt_alphas_bar_t = torch.sqrt(alphas_bar_t)
sqrt_1_minus_alphas_bar_t = torch.sqrt(1. - alphas_bar_t)

# the variance of q(xₜ₋₁ | xₜ, x₀) as in part 3
posterior_variance = (1. - alphas_bar_t_minus_1) / (1. - alphas_bar_t) * betas_t
```
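As a quick sanity check on these quantities (recomputed here so the snippet is self-contained): with the α̅₀ = 1 convention, the posterior variance at t = 1 is exactly zero, since x₀ is known there, and it is strictly positive afterwards:

```python
import torch

T = 300
betas = torch.linspace(1e-4, 0.02, T)
alphas = 1. - betas
alphas_bar = torch.cumprod(alphas, dim=0)
# α̅₀ = 1 by convention
alphas_bar_prev = torch.cat((torch.tensor([1.]), alphas_bar[:-1]))
posterior_variance = (1. - alphas_bar_prev) / (1. - alphas_bar) * betas
print(posterior_variance[0].item())  # 0.0
```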

The forward process sampling function is

```python
def sample_q(x0, t, noise=None):
    if noise is None:
        noise = torch.randn_like(x0)
    sqrt_alphas_bar_t_sampled = sample_by_t(sqrt_alphas_bar_t, t, x0.shape)
    sqrt_1_minus_alphas_bar_t_sampled = sample_by_t(sqrt_1_minus_alphas_bar_t, t, x0.shape)
    x_t = sqrt_alphas_bar_t_sampled * x0 + sqrt_1_minus_alphas_bar_t_sampled * noise
    return x_t
```

where we use the relation from part 2: xₜ = √α̅ₜ · x₀ + √(1 − α̅ₜ) · ε, with ε ∼ 𝒩(0, I).
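We can verify this closed form numerically: iterating the one-step update xₜ = √αₜ · xₜ₋₁ + √βₜ · ε many times should reproduce the closed-form mean and standard deviation. A minimal standalone sketch, using scalar "images" so the statistics are easy to estimate:

```python
import torch

torch.manual_seed(0)
T = 300
betas = torch.linspace(1e-4, 0.02, T)
alphas = 1. - betas
alphas_bar = torch.cumprod(alphas, dim=0)

# run the forward process step by step for many scalar "images" starting at x0 = 1
x = torch.ones(100_000)
t = 150
for i in range(t):
    x = torch.sqrt(alphas[i]) * x + torch.sqrt(betas[i]) * torch.randn_like(x)

# the closed form predicts mean √α̅ₜ · x0 and std √(1 − α̅ₜ)
print(x.mean().item(), torch.sqrt(alphas_bar[t - 1]).item())
print(x.std().item(), torch.sqrt(1 - alphas_bar[t - 1]).item())
```

The empirical mean and standard deviation after 150 iterated steps match the closed-form predictions, which is exactly why `sample_q` can jump straight to any timestep.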

Now, it’s time to see how an image looks when different amounts of noise are added during the forward process. Let’s start by loading an image to work with. We are going to use the `PIL` and `requests` modules to get an image from the web.

```python
import requests
from PIL import Image

url = 'https://images.pexels.com/photos/1557208/pexels-photo-1557208.jpeg?auto=compress&cs=tinysrgb&w=1260&h=750&dpr=2'
image_raw_data = requests.get(url, stream=True).raw
image = Image.open(image_raw_data)
```

To work with the image, we need to convert it into a PyTorch tensor. We also want to standardize it:

- A square shape
- Predefined size
- Values in the range [-1, 1]

We can achieve this with the help of the `transforms` module from `torchvision`.

```python
from torchvision.transforms import Compose, ToTensor, CenterCrop, Resize, Normalize

image_size = 128

transform = Compose([
    Resize(image_size),  # resize smaller edge to image_size
    CenterCrop(image_size),  # make a square image with size image_size
    ToTensor(),  # convert to tensor with shape CHW and values in the range [0, 1]
    Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),  # set the values to the range [-1, 1]
])
```

We also want a reverse transformation that converts a tensor back into a PIL image, and another that maps a tensor back to values in the range [0, 1].

```python
from torchvision.transforms import ToPILImage

reverse_transform_pil = Compose([
    Normalize(mean=(-1, -1, -1), std=(2, 2, 2)),
    ToPILImage()
])

reverse_transform_tensor = Compose([
    Normalize(mean=(-1, -1, -1), std=(2, 2, 2)),
])
```

Note that the image is still square with size `image_size` after the reverse transformation.

We can define a function that, given an image as a tensor and a timestep *t*, returns a noisy image sampled from the *q* distribution.

```python
def get_noisy_image(x0, t, transform=reverse_transform_pil):
    x_noisy = sample_q(x0, t)
    noise_image = transform(x_noisy.squeeze())
    return noise_image
```

and a function to display noisy images as a grid, where the rows are different images and the columns are the noisy versions at different timesteps.

```python
import matplotlib.pyplot as plt

def show_noisy_images(noisy_images):
    """
    Show and return a grid of noisy images where
    the rows are different images, and the columns
    are the noisy images in different timesteps.

    Args:
        noisy_images (list[list[PIL]]): a list of lists of images
            with noise from different timesteps.
    """
    num_of_image_sets = len(noisy_images)
    num_of_images_in_set = len(noisy_images[0])
    image_size = noisy_images[0][0].size[0]
    # leave a 1-pixel gap between tiles in the grid
    full_image = Image.new('RGB', (image_size * num_of_images_in_set + (num_of_images_in_set - 1),
                                   image_size * num_of_image_sets + (num_of_image_sets - 1)))
    for set_index, image_set in enumerate(noisy_images):
        for image_index, image in enumerate(image_set):
            full_image.paste(image, (image_index * image_size + image_index,
                                     set_index * image_size + set_index))
    plt.imshow(full_image)
    plt.axis('off')
    return full_image
```

Let’s see an example.

```python
x0 = transform(image)  # the image we downloaded, as a standardized tensor
show_noisy_images([[get_noisy_image(x0, torch.tensor([t])) for t in [0, 50, 100, 150, 200]]])
```