Intro to Diffusion Model — Part 7
In this final post in the series, we are going to train our diffusion model and see the results.
This post is part of a series of posts about diffusion models:
- Intro to Diffusion Model — Part 1
- Intro to Diffusion Model — Part 2
- Intro to Diffusion Model — Part 3
- Intro to Diffusion Model — Part 4
- Intro to Diffusion Model — Part 5
- Intro to Diffusion Model — Part 6
- Intro to Diffusion Model — Part 7 (this post)
- Full implementation
Before we train, we need a dataset. I chose the Smithsonian butterflies dataset from Hugging Face. First, we need to install the Hugging Face datasets library with its vision extras.
pip install datasets[vision] -q
Then, we can load the Smithsonian butterflies dataset.
from datasets import load_dataset
dataset = load_dataset("huggan/smithsonian_butterflies_subset", split='train')
Now, we want to create a data loader for this dataset and apply transformations similar to the ones we used in part 4.
from torch.utils.data import DataLoader
from torchvision.transforms import RandomHorizontalFlip, Compose, ToTensor, Resize, Normalize

image_size = 64

# Resize, augment with random horizontal flips, and scale pixel values to [-1, 1]
transform = Compose([
    Resize((image_size, image_size)),
    RandomHorizontalFlip(),
    ToTensor(),
    Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
])

def transforms(data):
    # Apply the transform to every PIL image in the batch
    images = [transform(im) for im in data['image']]
    return {'images': images}

dataset.set_transform(transforms)

batch_size = 32
train_dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
We can look at an example from the dataset with
batch = next(iter(train_dataloader))
reverse_transform_pil(batch['images'][20])
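To make the value ranges concrete, here is a minimal sketch of the arithmetic that Normalize performs with a mean and std of 0.5, and of how it can be undone before displaying an image. The helper names normalize and denormalize are hypothetical and only used for illustration; the reverse transform used in this series may be implemented differently.

```python
import torch

def normalize(x: torch.Tensor) -> torch.Tensor:
    # Same arithmetic as Normalize(mean=0.5, std=0.5): maps [0, 1] to [-1, 1]
    return (x - 0.5) / 0.5

def denormalize(x: torch.Tensor) -> torch.Tensor:
    # Inverse mapping: brings [-1, 1] back to [0, 1]
    return x * 0.5 + 0.5

pixels = torch.tensor([0.0, 0.25, 0.5, 1.0])   # values as produced by ToTensor
normalized = normalize(pixels)                 # black maps to -1, white maps to 1
restored = denormalize(normalized)             # recovers the original values
```

The model only ever sees values in [-1, 1], which is why every sampled image has to go through a reverse transform before it is shown or saved.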
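We need to define a loss for the training. Here, we use the L1 loss between the noise we added to the images and the noise the model predicts.
import torch
import torch.nn.functional as F

def compute_loss(model, x0, t, noise=None):
    # Sample fresh Gaussian noise unless the caller provides it
    if noise is None:
        noise = torch.randn_like(x0)
    # Forward process: noise the clean images up to timestep t
    x_t = sample_q(x0, t, noise)
    predicted_noise = model(x_t, t)
    loss = F.l1_loss(predicted_noise, noise)
    return loss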
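As a sanity check, the loss can be run end to end without the real network. The sketch below is only an illustration under stated assumptions: the linear beta schedule, the small num_timesteps, the sample_q stand-in (the usual closed form of the forward process), and the zero_model dummy are all hypothetical and may differ from the versions built in the earlier parts.

```python
import torch
import torch.nn.functional as F

num_timesteps = 100                                   # assumption: small T for illustration
betas = torch.linspace(1e-4, 0.02, num_timesteps)     # assumption: linear schedule
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)

def sample_q(x0, t, noise):
    # Stand-in: x_t = sqrt(abar_t) * x0 + sqrt(1 - abar_t) * noise
    abar = alphas_cumprod[t].view(-1, 1, 1, 1)
    return abar.sqrt() * x0 + (1.0 - abar).sqrt() * noise

def zero_model(x_t, t):
    # Dummy "network" that always predicts zero noise
    return torch.zeros_like(x_t)

torch.manual_seed(0)
x0 = torch.rand(8, 3, 16, 16) * 2 - 1                 # fake images in [-1, 1]
t = torch.randint(1, num_timesteps, (8,))             # integers in [1, num_timesteps - 1]
noise = torch.randn_like(x0)

x_t = sample_q(x0, t, noise)
loss_zero = F.l1_loss(zero_model(x_t, t), noise)      # equals the mean of |noise|
loss_perfect = F.l1_loss(noise, noise)                # a perfect predictor gives 0
```

A model that always predicts zero gets a loss of roughly 0.8, the mean absolute value of standard Gaussian noise, so a training loss well below that value indicates the network has learned something about the noise.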
Now, we can initialize the model and set an optimizer.
from torch.optim import Adam
device = "cuda" if torch.cuda.is_available() else "cpu"
model = DiffusionUnet(dim=image_size, channels=3, dim_mults=(1, 2, 4, 8)).to(device)
optimizer = Adam(model.parameters(), lr=1e-4)
We will want to save results during training, so we create a results folder.
from pathlib import Path
results_folder = Path("./results")
results_folder.mkdir(exist_ok=True)
For the training, I’m going to use Google Colab, so I want to save the model to my Google Drive and be able to continue the training later in case of a timeout. For that, we first need to mount Google Drive.
from google.colab import drive
drive.mount('/content/gdrive')
Now, we check whether we already have a model saved in our drive. If there is one, we load it and resume from the saved epoch; otherwise, we start with an empty dictionary.
model_saved_file_path = Path('/content/gdrive/MyDrive/diffusion_checkpoints/saved_model.pth')
if model_saved_file_path.exists():
    # map_location lets a checkpoint saved on GPU load on a CPU-only runtime
    saved_model = torch.load(str(model_saved_file_path), map_location=device)
else:
    model_saved_file_path.parent.mkdir(parents=True, exist_ok=True)
    saved_model = {}

start_epoch = 0
if saved_model:
    print('loading model')
    model.load_state_dict(saved_model['model'])
    optimizer.load_state_dict(saved_model['optimizer'])
    # Resume from the epoch stored in the checkpoint
    start_epoch = saved_model['epoch']
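The save-and-resume pattern can be tried in isolation. The following is a minimal sketch, assuming a tiny nn.Linear as a stand-in for the U-Net and a temporary directory instead of Google Drive; the helper names save_checkpoint and load_checkpoint are hypothetical and not part of the series code.

```python
import tempfile
from pathlib import Path

import torch
from torch import nn
from torch.optim import Adam

def save_checkpoint(path, model, optimizer, epoch):
    # Store everything needed to resume: weights, optimizer state, next epoch
    torch.save({'model': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'epoch': epoch}, str(path))

def load_checkpoint(path, model, optimizer):
    # Return the epoch to resume from; 0 when no checkpoint exists yet
    if not Path(path).exists():
        return 0
    saved = torch.load(str(path), map_location='cpu')
    model.load_state_dict(saved['model'])
    optimizer.load_state_dict(saved['optimizer'])
    return saved['epoch']

with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / 'saved_model.pth'
    model = nn.Linear(4, 2)                       # tiny stand-in for the U-Net
    optimizer = Adam(model.parameters(), lr=1e-4)
    fresh_start = load_checkpoint(path, model, optimizer)   # no file yet
    save_checkpoint(path, model, optimizer, epoch=3)

    restored = nn.Linear(4, 2)                    # new instance, different random weights
    restored_opt = Adam(restored.parameters(), lr=1e-4)
    start_epoch = load_checkpoint(path, restored, restored_opt)
    same_weights = torch.equal(model.weight, restored.weight)
```

Saving the optimizer state along with the weights matters for Adam, because its running moment estimates would otherwise restart from zero after every timeout.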
And finally, we implement the training loop.
import datetime
import time

import numpy as np
from torchvision.utils import save_image

epochs = 1000
loss_steps = 50      # report the mean loss every loss_steps batches
sample_every = 1000  # sample from the model every sample_every batches
loss_for_mean = np.zeros(loss_steps)
prev_time = time.time()

for epoch in range(start_epoch, epochs):
    for batch_index, batch in enumerate(train_dataloader):
        images = batch['images'].to(device)

        # sample t according to Algorithm 1
        t = torch.randint(1, num_timesteps, (images.shape[0],), device=device).long()
        loss = compute_loss(model, images, t)

        current_step = batch_index + epoch * len(train_dataloader)
        if current_step % loss_steps == 0:
            # Determine approximate time left
            batches_left = epochs * len(train_dataloader) - current_step
            time_left = datetime.timedelta(seconds=batches_left * (time.time() - prev_time) / loss_steps)
            prev_time = time.time()
            print(f'Loss at epoch {epoch}, batch {batch_index}: {loss_for_mean.mean()} | time remaining: {time_left}')
            loss_for_mean[:] = 0
        loss_for_mean[current_step % loss_steps] = loss.item()

        # Update the model weights
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if current_step % sample_every == 0:
            # Sample a few images from the reverse process and save them
            batch_to_sample = 5
            sample_images_list = sampling(model, (batch_to_sample, 3, image_size, image_size))
            sample_images = torch.cat(sample_images_list, dim=0)
            sample_images = reverse_transform_tensor(sample_images)
            save_image(sample_images, str(results_folder / f'sample_{current_step}.png'), nrow=batch_to_sample)

    # End of epoch: save a checkpoint that records the next epoch to run
    saved_model['epoch'] = epoch + 1
    saved_model['model'] = model.state_dict()
    saved_model['optimizer'] = optimizer.state_dict()
    torch.save(saved_model, str(model_saved_file_path))
We train according to Algorithm 1 from part 3. We take a batch of clean images and sample a timestep uniformly for each image; the model then tries to predict the noise we added to the images, and we compute the loss. We store the loss values for loss_steps steps (a step is one batch from the data loader), and after that, we print the mean of the loss over the last loss_steps steps. Note that the loss printed the first time is meaningless, since the loss_for_mean buffer has not been filled yet. We also estimate the remaining time and print it with the loss. We update the weights of the model, and once every sample_every steps we sample from the reverse process and save the result. Finally, at the end of each epoch, we save the states of the model and the optimizer together with the epoch to resume from.
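The logging logic is easier to follow on a toy run. The sketch below replays the buffer and the time estimate with made-up loss values and a tiny window; fake_losses, the window size of 4, and the helper estimate_time_left are assumptions for illustration only.

```python
import datetime

def estimate_time_left(batches_left, seconds_for_window, window_size):
    # Average seconds per batch over the last window, scaled by the batches left
    seconds_per_batch = seconds_for_window / window_size
    return datetime.timedelta(seconds=batches_left * seconds_per_batch)

loss_steps = 4                          # assumption: tiny window for illustration
loss_buffer = [0.0] * loss_steps
fake_losses = [0.9, 0.7, 0.5, 0.3, 0.2, 0.2, 0.1, 0.1, 0.05]
reported = []                           # (step, mean loss) pairs that would be printed

for step, loss in enumerate(fake_losses):
    if step % loss_steps == 0:
        # Report the mean of the previous window, then clear the buffer
        reported.append((step, sum(loss_buffer) / loss_steps))
        loss_buffer = [0.0] * loss_steps
    loss_buffer[step % loss_steps] = loss

# 100 batches left, and the last 50 batches took 10 seconds in total
eta = estimate_time_left(batches_left=100, seconds_for_window=10.0, window_size=50)
```

The first entry in reported is 0.0 because nothing has been written to the buffer at step 0, which is why the first printed loss should be ignored; every later entry is the mean over the preceding window.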
After finishing the training, we can sample new butterfly images with the following code:
from torchvision.utils import make_grid
reverse_transform_pil(make_grid(sampling(model, (16, 3, 64, 64))[-1], nrow=8))
We can also see the reverse process using
from torchvision.utils import make_grid
image_steps_list = sampling(model, (4, 3, 64, 64), 300)
image_steps = torch.cat(image_steps_list, dim=0)
reverse_transform_pil(make_grid(image_steps, nrow=4))
The code above will show the last 300 steps of sampling 4 butterflies.
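To see why the grid ends up with one row per step, here is a small sketch of the shapes involved. It assumes, as the usage in this post suggests, that the sampling function returns a list with one batch tensor per step; the zero tensors and the tiny sizes are stand-ins for illustration.

```python
import math

import torch

num_images, num_steps, size = 4, 3, 8     # assumption: tiny sizes for illustration

# Stand-in for the list returned by the sampling function: one batch per step
image_steps_list = [torch.zeros(num_images, 3, size, size) for _ in range(num_steps)]

# Concatenating along the batch axis places the steps one after another
image_steps = torch.cat(image_steps_list, dim=0)

# With nrow=num_images, each grid row holds all images of a single step
rows = math.ceil(image_steps.shape[0] / num_images)
```

With 4 images and 300 steps, the same arithmetic gives a grid of 300 rows and 4 columns, so reading a column from top to bottom follows one butterfly through the reverse process.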