一、概述
Variational Autoencoder(VAE)是一种生成模型,广泛应用于图像与文本生成等领域。它可以将数据映射到一个潜在空间中,并通过解码器从这个潜在空间重新生成出输入数据。
相较于其他生成模型,VAE采用了贝叶斯推断的方法,能够更好地描述数据的不确定性。其目标是使数据在潜在空间中服从一个特定的分布,从而使得通过这个分布采样的数据与真实数据尽量相似。
VAE包含两个主要的部分:编码器和解码器。编码器将输入数据压缩到潜在空间中,解码器则从潜在空间中重建出数据。在这个过程中,中间的潜在空间起到了"过渡"的作用,即将输入数据从原始空间映射到潜在空间,再从潜在空间映射回原始空间。通过对潜在空间的建模,我们可以生成与数据分布相似的新数据。
二、编码器与解码器
编码器和解码器是VAE的核心组成部分。编码器将输入数据x映射到潜在空间z中的一个概率分布,解码器则从潜在空间z中采样,并生成与原始数据x相似的新数据。
2.1 编码器
编码器的主要目的是将输入数据x映射到潜在空间z中的一个概率分布,即求解p(z|x)。在VAE中,我们假设p(z|x)是一个高斯分布,其均值和方差可以用x计算得到:
class Encoder(nn.Module):
def __init__(self, input_size, hidden_size, latent_size):
super().__init__()
self.fc1 = nn.Linear(input_size, hidden_size)
self.fc21 = nn.Linear(hidden_size, latent_size)
self.fc22 = nn.Linear(hidden_size, latent_size)
def forward(self, x):
hidden = F.relu(self.fc1(x))
mu = self.fc21(hidden)
logvar = self.fc22(hidden)
return mu, logvar
在上面的代码中,Encoder是编码器的实现,其输入为x,输出为潜在空间的均值mu和对数方差logvar。两个分布之间的KL散度可以用以下公式计算:
def kl_loss(mu, logvar):
return -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
上述代码中kl_loss函数用于计算KL散度。
2.2 解码器
解码器的主要目的是从潜在空间z中采样出一组随机向量z,并通过解码器将其映射回到原始数据空间中,即求解p(x|z)。一般的,我们假设p(x|z)是一个高斯分布,其均值与方差可以用z计算得到:
class Decoder(nn.Module):
def __init__(self, latent_size, hidden_size, output_size):
super().__init__()
self.fc1 = nn.Linear(latent_size, hidden_size)
self.fc2 = nn.Linear(hidden_size, output_size)
def forward(self, z):
hidden = F.relu(self.fc1(z))
output = torch.sigmoid(self.fc2(hidden))
return output
上述代码中Decoder是解码器的实现,其输入为潜在空间的向量z,输出为生成的图像数据output。
三、VAE的训练
训练VAE的目标是最小化重建损失和KL散度。重建损失用于度量生成数据与真实数据之间的差异,即使得通过解码器生成的数据尽量接近真实数据。KL散度用于度量从真实数据分布到潜在空间分布的距离,使得生成的数据与真实数据在潜在空间分布上尽量接近。
3.1 重建损失
重建损失的计算是相对简单的,即通过解码器从潜在空间中采样随机向量,并计算生成数据与真实数据之间的欧几里得距离:
def reconstruction_loss(x, x_origin):
return F.mse_loss(x_origin, x, reduction='sum')
上述代码中reconstruction_loss函数用于计算重建损失。
3.2 KL散度
KL散度中的μ和logσ都是计算得到的,具体如下:
def loss_function(x, x_origin, mu, logvar):
BCE = reconstruction_loss(x, x_origin)
KLD = kl_loss(mu, logvar)
return BCE + KLD, BCE, KLD
上述代码中的loss_function函数是整个VAE的损失函数,其输入为真实数据x和生成数据x_origin,以及潜在空间中的均值mu和对数方差logvar。在训练时,我们将重建损失和KL散度加权相加,得到整个VAE的损失函数。其中,BCE代表重建损失,KLD代表KL散度。
四、代码示例
下面是一个完整的VAE模型的代码实现:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
class Encoder(nn.Module):
def __init__(self, input_size, hidden_size, latent_size):
super().__init__()
self.fc1 = nn.Linear(input_size, hidden_size)
self.fc21 = nn.Linear(hidden_size, latent_size)
self.fc22 = nn.Linear(hidden_size, latent_size)
def forward(self, x):
hidden = F.relu(self.fc1(x))
mu = self.fc21(hidden)
logvar = self.fc22(hidden)
return mu, logvar
class Decoder(nn.Module):
def __init__(self, latent_size, hidden_size, output_size):
super().__init__()
self.fc1 = nn.Linear(latent_size, hidden_size)
self.fc2 = nn.Linear(hidden_size, output_size)
def forward(self, z):
hidden = F.relu(self.fc1(z))
output = torch.sigmoid(self.fc2(hidden))
return output
def kl_loss(mu, logvar):
return -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
def reconstruction_loss(x, x_origin):
return F.mse_loss(x_origin, x, reduction='sum')
def loss_function(x, x_origin, mu, logvar):
BCE = reconstruction_loss(x, x_origin)
KLD = kl_loss(mu, logvar)
return BCE + KLD, BCE, KLD
class VAE(nn.Module):
def __init__(self, input_size, hidden_size, latent_size):
super().__init__()
self.encoder = Encoder(input_size, hidden_size, latent_size)
self.decoder = Decoder(latent_size, hidden_size, input_size)
def encode(self, x):
return self.encoder(x)
def decode(self, z):
return self.decoder(z)
def reparameterize(self, mu, logvar):
std = torch.exp(0.5 * logvar)
eps = torch.randn_like(std)
return mu + eps * std
def forward(self, x):
mu, logvar = self.encode(x.view(-1, input_size))
z = self.reparameterize(mu, logvar)
return self.decode(z), mu, logvar
# train
input_size = 784
hidden_size = 256
latent_size = 10
epochs = 10
batch_size = 64
lr = 0.001
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
train_loader = torch.utils.data.DataLoader(
datasets.MNIST('../data', train=True, download=True,
transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])),
batch_size=batch_size, shuffle=True)
model = VAE(input_size, hidden_size, latent_size).to(device)
optimizer = optim.Adam(model.parameters(), lr=lr)
for epoch in range(1, epochs+1):
model.train()
train_loss = 0
train_BCE_loss = 0
train_KLD_loss = 0
for batch_idx, (data, _) in enumerate(train_loader):
data = data.to(device)
optimizer.zero_grad()
recon_batch, mu, logvar = model(data)
loss, BCE_loss, KLD_loss = loss_function(recon_batch, data, mu, logvar)
loss.backward()
train_loss += loss.item()
train_BCE_loss += BCE_loss.item()
train_KLD_loss += KLD_loss.item()
optimizer.step()
print('Epoch [{}/{}], Loss: {:.4f}, BCE Loss: {:.4f}, KL Divergence: {:.4f}'.format(
epoch, epochs, train_loss / len(train_loader.dataset),
train_BCE_loss / len(train_loader.dataset),
train_KLD_loss / len(train_loader.dataset)))