Forward Process
In the forward process of a diffusion model, the data undergoes a progressive noising procedure until the original signal is fully degraded, leaving a well-behaved distribution, typically a Gaussian. At each time step t, the next data point x_{t+1} is sampled from the current data point x_t using a Gaussian distribution q(x_{t+1}|x_t). The mean and covariance of this distribution are determined by a set of hyperparameters called betas.
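To make this concrete, a single forward step can be sketched as follows (an illustrative snippet, not part of the final implementation; beta_t here is a hypothetical scalar tensor):

import torch

def forward_step(x_t, beta_t):
    # One step of q(x_{t+1} | x_t) = N(sqrt(1 - beta_t) * x_t, beta_t * I)
    noise = torch.randn_like(x_t)
    return torch.sqrt(1. - beta_t) * x_t + torch.sqrt(beta_t) * noise

x1 = forward_step(torch.randn(64, 2), torch.tensor(1e-2))  # a batch of 2D points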
When sampling a data point x_t, we do not need the naive solution of iterating over t time steps, because the conditional distribution q(x_t|x_0) has been shown to admit a tractable analytical form. We can therefore sample x_t directly as a function of x_0, which is what the line computing xt in forward_process below does. This avoids a costly iterative loop and allows x_t to be sampled efficiently from x_0.
Besides sampling x_t from the mean and covariance of q(x_t|x_0), we also need the mean and standard deviation of q(x_{t-1}|x_t, x_0) to compute the loss. These quantities can be derived analytically with Bayes' theorem and incorporated directly into the code.
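For reference, the standard closed form obtained with Bayes' theorem is

q(x_{t-1}|x_t, x_0) = N(mu_tilde_t, beta_tilde_t * I)
mu_tilde_t = (sqrt(alpha_bar_{t-1}) * beta_t / (1 - alpha_bar_t)) * x_0 + (sqrt(alpha_t) * (1 - alpha_bar_{t-1}) / (1 - alpha_bar_t)) * x_t
beta_tilde_t = ((1 - alpha_bar_{t-1}) / (1 - alpha_bar_t)) * beta_t

where alpha_t = 1 - beta_t and alpha_bar_t is the cumulative product of the alphas. The implementation below computes an algebraically equivalent precision-weighted form taken from the original repository.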
def forward_process(self, x0, t):
    t = t - 1  # Start indexing at 0
    beta_forward = self.beta[t]
    alpha_forward = self.alpha[t]
    alpha_cum_forward = self.alpha_bar[t]
    # Sample x_t directly from x_0 using the closed form of q(x_t | x_0)
    xt = x0 * torch.sqrt(alpha_cum_forward) + torch.randn_like(x0) * torch.sqrt(1. - alpha_cum_forward)
    # Mean and variance of the forward posterior q(x_{t-1} | x_t, x_0),
    # combined as a precision-weighted average of the x_0 and x_t terms
    mu1_scl = torch.sqrt(alpha_cum_forward / alpha_forward)
    mu2_scl = 1. / torch.sqrt(alpha_forward)
    cov1 = 1. - alpha_cum_forward / alpha_forward
    cov2 = beta_forward / alpha_forward
    lam = 1. / cov1 + 1. / cov2  # total precision
    mu = (x0 * mu1_scl / cov1 + xt * mu2_scl / cov2) / lam
    sigma = torch.sqrt(1. / lam)
    return mu, sigma, xt
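As a quick sanity check, one might call it as follows (a hypothetical snippet, assuming the sample_batch helper and a DiffusionModel instance named model, both defined later in the post):

x0 = torch.from_numpy(sample_batch(128)).float()
mu, sigma, xt = model.forward_process(x0, 20)
# mu and xt have the same shape as x0; sigma is a single scalar for this time step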
Reverse Process
The goal of the reverse process in a diffusion model is to approximate, at each time step t, the inverse of the forward process. Once the model is trained, we can sample noise directly from the well-behaved distribution and use the reverse process to generate data samples, such as two-dimensional data points.
Implementing the reverse process is straightforward. By construction, q(x_t|x_{t-1}) is a Gaussian with a small standard deviation, so p(x_{t-1}|x_t) is also approximately Gaussian and can therefore be parameterized by its mean and covariance matrix. The goal is to train a model that accurately estimates these parameters at every time step, enabling the generation of high-quality synthetic data.
Once the model approximates the parameters (mean and covariance) of the distribution p(x_{t-1}|x_t), these parameters can be used to draw samples from it.
def reverse(self, xt, t):
    t = t - 1  # Start indexing at 0
    if t == 0: return None, None, xt
    # The network predicts the mean and the log-variance of p(x_{t-1} | x_t)
    mu, h = self.model(xt, t).chunk(2, dim=1)
    sigma = torch.sqrt(torch.exp(h))
    samples = mu + torch.randn_like(xt) * sigma
    return mu, sigma, samples
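A note on the parameterization: the network outputs 2 * data_dim values per point, which chunk splits into a mean and a log-variance h; taking sqrt(exp(h)) guarantees a strictly positive standard deviation without constraining the network output. A minimal illustration on a dummy tensor:

out = torch.randn(4, 4)           # stand-in for a network output on 2D data
mu, h = out.chunk(2, dim=1)       # two (4, 2) halves
sigma = torch.sqrt(torch.exp(h))  # positive for any value of h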
Sampling
Once the reverse process is implemented, it can be called recursively, starting from noise drawn from the well-behaved distribution. Through these recursive calls, the reverse process produces data that should resemble the training data.
def sample(self, size, device):
    noise = torch.randn((size, 2)).to(device)  # start from pure Gaussian noise
    samples = [noise]
    for t in range(self.n_steps):
        # Walk backwards from step n_steps down to 1
        _, _, x = self.reverse(samples[-1], self.n_steps - t)
        samples.append(x)
    return samples
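Once the model is trained, generation is a single call (a hypothetical usage, assuming a model trained on 'cuda'):

samples = model.sample(5000, 'cuda')
x_generated = samples[-1]  # the last entry of the trajectory is the final sample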
Constructor
To complete the model, we write the constructor. It takes a few inputs: the time-dependent model, the number of diffusion steps, and the device on which to run. The hyperparameter beta defines the variance of each diffusion step in the forward process, and from it we derive alpha and alpha_bar, which keep the rest of the implementation concise.
class DiffusionModel(nn.Module):
    def __init__(self, model: nn.Module, n_steps=40, device='cuda'):
        super().__init__()
        self.model = model
        self.device = device
        # Sigmoid schedule: beta ramps smoothly from 1e-5 up to 3e-1
        betas = torch.linspace(-18, 10, n_steps)
        self.beta = torch.sigmoid(betas) * (3e-1 - 1e-5) + 1e-5
        self.alpha = 1. - self.beta
        self.alpha_bar = torch.cumprod(self.alpha, dim=0)
        self.n_steps = n_steps
        self.sigma2 = self.beta
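With this sigmoid schedule, beta ramps smoothly from about 1e-5 up to about 3e-1, and alpha_bar decays towards zero, so by the final step most of the original signal has been destroyed. A quick, hypothetical way to inspect the schedule (the network argument is irrelevant here, so an identity module stands in):

schedule = DiffusionModel(nn.Identity(), device='cpu')
print(schedule.beta[0].item(), schedule.beta[-1].item())  # ~1e-5 and ~3e-1
print(schedule.alpha_bar[-1].item())  # small: little of x_0 survives at t = T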
Training
During training, the goal is to maximize the model's log-likelihood, which unfortunately cannot be computed analytically. An alternative is to optimize a lower bound on the likelihood, as demonstrated in the original paper. This lower bound involves the Kullback-Leibler (KL) divergence between the distributions q(x_{t-1}|x_t, x_0) and p(x_{t-1}|x_t). Since both are Gaussian, the KL divergence has a well-known analytical form. By optimizing this lower bound, we indirectly maximize the model's log-likelihood and improve its performance during training.
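For two univariate Gaussians, the divergence is

KL(N(mu_1, sigma_1^2) || N(mu_2, sigma_2^2)) = log(sigma_2 / sigma_1) + (sigma_1^2 + (mu_1 - mu_2)^2) / (2 * sigma_2^2) - 1/2

and this is exactly the expression evaluated in the training loop below, with the forward posterior as the first distribution and the model's prediction as the second.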
def train(model, optimizer, nb_epochs=150_000, batch_size=64_000):
    training_loss = []
    for _ in tqdm(range(nb_epochs)):
        x0 = torch.from_numpy(sample_batch(batch_size)).float().to(device)
        t = np.random.randint(2, 40 + 1)  # draw a random time step
        # Forward posterior q(x_{t-1} | x_t, x_0) and the noised sample x_t
        mu_posterior, sigma_posterior, xt = model.forward_process(x0, t)
        # Model's approximation of p(x_{t-1} | x_t)
        mu, sigma, _ = model.reverse(xt, t)
        # Closed-form KL divergence between the two Gaussians
        KL = (torch.log(sigma) - torch.log(sigma_posterior) + (sigma_posterior ** 2 + (mu_posterior - mu) ** 2) / (
                2 * sigma ** 2) - 0.5)
        loss = KL.mean()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        training_loss.append(loss.item())
Putting It All Together
Finally, all the components can be assembled effortlessly into the complete implementation. After a few hours of training, the model should produce an excellent generative model of the Swiss roll distribution.
if __name__ == "__main__":
    device = 'cuda'
    model_mlp = MLP(hidden_dim=128).to(device)
    model = DiffusionModel(model_mlp)
    optimizer = torch.optim.Adam(model_mlp.parameters(), lr=1e-4)
    train(model, optimizer)
Full code:
import torch
import numpy as np
from torch import nn
from tqdm import tqdm
import torch.utils.data
import matplotlib.pyplot as plt
from sklearn.datasets import make_swiss_roll


def sample_batch(size):
    x, _ = make_swiss_roll(size)
    return x[:, [2, 0]] / 10.0 * np.array([1, -1])


class MLP(nn.Module):
    def __init__(self, N=40, data_dim=2, hidden_dim=64):
        super(MLP, self).__init__()
        self.network_head = nn.Sequential(nn.Linear(data_dim, hidden_dim), nn.ReLU(),
                                          nn.Linear(hidden_dim, hidden_dim), nn.ReLU(), )
        self.network_tail = nn.ModuleList([nn.Sequential(nn.Linear(hidden_dim, hidden_dim),
                                                         nn.ReLU(), nn.Linear(hidden_dim, data_dim * 2)
                                                         ) for _ in range(N)])

    def forward(self, x, t: int):
        h = self.network_head(x)
        return self.network_tail[t](h)


class DiffusionModel(nn.Module):
    def __init__(self, model: nn.Module, n_steps=40, device='cuda'):
        super().__init__()
        self.model = model
        self.device = device
        betas = torch.linspace(-18, 10, n_steps)
        self.beta = torch.sigmoid(betas) * (3e-1 - 1e-5) + 1e-5
        self.alpha = 1. - self.beta
        self.alpha_bar = torch.cumprod(self.alpha, dim=0)
        self.n_steps = n_steps
        self.sigma2 = self.beta

    def forward_process(self, x0, t):
        t = t - 1  # Start indexing at 0
        beta_forward = self.beta[t]
        alpha_forward = self.alpha[t]
        alpha_cum_forward = self.alpha_bar[t]
        xt = x0 * torch.sqrt(alpha_cum_forward) + torch.randn_like(x0) * torch.sqrt(1. - alpha_cum_forward)
        # Retrieved from https://github.com/Sohl-Dickstein/Diffusion-Probabilistic-Models/blob/master/model.py#L203
        mu1_scl = torch.sqrt(alpha_cum_forward / alpha_forward)
        mu2_scl = 1. / torch.sqrt(alpha_forward)
        cov1 = 1. - alpha_cum_forward / alpha_forward
        cov2 = beta_forward / alpha_forward
        lam = 1. / cov1 + 1. / cov2
        mu = (x0 * mu1_scl / cov1 + xt * mu2_scl / cov2) / lam
        sigma = torch.sqrt(1. / lam)
        return mu, sigma, xt

    def reverse(self, xt, t):
        t = t - 1  # Start indexing at 0
        if t == 0: return None, None, xt
        mu, h = self.model(xt, t).chunk(2, dim=1)
        sigma = torch.sqrt(torch.exp(h))
        samples = mu + torch.randn_like(xt) * sigma
        return mu, sigma, samples

    def sample(self, size, device):
        noise = torch.randn((size, 2)).to(device)
        samples = [noise]
        for t in range(self.n_steps):
            _, _, x = self.reverse(samples[-1], self.n_steps - t)
            samples.append(x)
        return samples


def plot(model):
    plt.figure(figsize=(10, 6))
    x0 = sample_batch(5000)
    x20 = model.forward_process(torch.from_numpy(x0).to(device), 20)[-1].data.cpu().numpy()
    x40 = model.forward_process(torch.from_numpy(x0).to(device), 40)[-1].data.cpu().numpy()
    data = [x0, x20, x40]
    for i, t in enumerate([0, 20, 39]):
        plt.subplot(2, 3, 1 + i)
        plt.scatter(data[i][:, 0], data[i][:, 1], alpha=.1, s=1)
        plt.xlim([-2, 2])
        plt.ylim([-2, 2])
        plt.gca().set_aspect('equal')
        if t == 0: plt.ylabel(r'$q(\mathbf{x}^{(0...T)})$', fontsize=17, rotation=0, labelpad=60)
        if i == 0: plt.title(r'$t=0$', fontsize=17)
        if i == 1: plt.title(r'$t=\frac{T}{2}$', fontsize=17)
        if i == 2: plt.title(r'$t=T$', fontsize=17)
    samples = model.sample(5000, device)
    for i, t in enumerate([0, 20, 40]):
        plt.subplot(2, 3, 4 + i)
        plt.scatter(samples[40 - t][:, 0].data.cpu().numpy(), samples[40 - t][:, 1].data.cpu().numpy(),
                    alpha=.1, s=1, c='r')
        plt.xlim([-2, 2])
        plt.ylim([-2, 2])
        plt.gca().set_aspect('equal')
        if t == 0: plt.ylabel(r'$p(\mathbf{x}^{(0...T)})$', fontsize=17, rotation=0, labelpad=60)
    plt.savefig("Imgs/diffusion_model.png", bbox_inches='tight')
    plt.close()


def train(model, optimizer, nb_epochs=150_000, batch_size=64_000):
    training_loss = []
    for _ in tqdm(range(nb_epochs)):
        x0 = torch.from_numpy(sample_batch(batch_size)).float().to(device)
        t = np.random.randint(2, 40 + 1)
        mu_posterior, sigma_posterior, xt = model.forward_process(x0, t)
        mu, sigma, _ = model.reverse(xt, t)
        KL = (torch.log(sigma) - torch.log(sigma_posterior) + (sigma_posterior ** 2 + (mu_posterior - mu) ** 2) / (
                2 * sigma ** 2) - 0.5)
        loss = KL.mean()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        training_loss.append(loss.item())


if __name__ == "__main__":
    device = 'cuda'
    model_mlp = MLP(hidden_dim=128).to(device)
    model = DiffusionModel(model_mlp)
    optimizer = torch.optim.Adam(model_mlp.parameters(), lr=1e-4)
    train(model, optimizer)
    plot(model)