您的位置:

VAE模型详解

一、概述

Variational Autoencoder(VAE)是一种生成模型,广泛应用于图像与文本生成等领域。它可以将数据映射到一个潜在空间中,并通过解码器从这个潜在空间重新生成出输入数据。

相较于其他生成模型,VAE采用了贝叶斯推断的方法,能够更好地描述数据的不确定性。其目标是使数据在潜在空间中服从一个特定的分布,从而使得通过这个分布采样的数据与真实数据尽量相似。

VAE包含两个主要的部分:编码器和解码器。编码器将输入数据压缩到潜在空间中,解码器则从潜在空间中重建出数据。在这个过程中,中间的潜在空间起到了"过渡"的作用,即将输入数据从原始空间映射到潜在空间,再从潜在空间映射回原始空间。通过对潜在空间的建模,我们可以生成与数据分布相似的新数据。

二、编码器与解码器

编码器和解码器是VAE的核心组成部分。编码器将输入数据x映射到潜在空间z中的一个概率分布,解码器则从潜在空间z中采样,并生成与原始数据x相似的新数据。

2.1 编码器

编码器的主要目的是将输入数据x映射到潜在空间z中的一个概率分布,即求解p(z|x)。在VAE中,我们假设p(z|x)是一个高斯分布,其均值和方差可以用x计算得到:


class Encoder(nn.Module):
    def __init__(self, input_size, hidden_size, latent_size):
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc21 = nn.Linear(hidden_size, latent_size)
        self.fc22 = nn.Linear(hidden_size, latent_size)
        
    def forward(self, x):
        hidden = F.relu(self.fc1(x))
        mu = self.fc21(hidden)
        logvar = self.fc22(hidden)
        return mu, logvar

在上面的代码中,Encoder是编码器的实现,其输入为x,输出为潜在空间的均值mu和对数方差logvar。两个分布之间的KL散度可以用以下公式计算:


def kl_loss(mu, logvar):
    return -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

上述代码中kl_loss函数用于计算KL散度。

2.2 解码器

解码器的主要目的是从潜在空间z中采样出一组随机向量z,并通过解码器将其映射回到原始数据空间中,即求解p(x|z)。一般的,我们假设p(x|z)是一个高斯分布,其均值与方差可以用z计算得到:


class Decoder(nn.Module):
    def __init__(self, latent_size, hidden_size, output_size):
        super().__init__()
        self.fc1 = nn.Linear(latent_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, output_size)
        
    def forward(self, z):
        hidden = F.relu(self.fc1(z))
        output = torch.sigmoid(self.fc2(hidden))
        return output

上述代码中Decoder是解码器的实现,其输入为潜在空间的向量z,输出为生成的图像数据output。

三、VAE的训练

训练VAE的目标是最小化重建损失和KL散度。重建损失用于度量生成数据与真实数据之间的差异,即使得通过解码器生成的数据尽量接近真实数据。KL散度用于度量从真实数据分布到潜在空间分布的距离,使得生成的数据与真实数据在潜在空间分布上尽量接近。

3.1 重建损失

重建损失的计算是相对简单的,即通过解码器从潜在空间中采样随机向量,并计算生成数据与真实数据之间的欧几里得距离:


def reconstruction_loss(x, x_origin):
    return F.mse_loss(x_origin, x, reduction='sum')

上述代码中reconstruction_loss函数用于计算重建损失。

3.2 KL散度

KL散度中的μ和logσ都是计算得到的,具体如下:


def loss_function(x, x_origin, mu, logvar):
    BCE = reconstruction_loss(x, x_origin)
    KLD = kl_loss(mu, logvar)
    return BCE + KLD, BCE, KLD

上述代码中的loss_function函数是整个VAE的损失函数,其输入为真实数据x和生成数据x_origin,以及潜在空间中的均值mu和对数方差logvar。在训练时,我们将重建损失和KL散度加权相加,得到整个VAE的损失函数。其中,BCE代表重建损失,KLD代表KL散度。

四、代码示例

下面是一个完整的VAE模型的代码实现:


import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

class Encoder(nn.Module):
    def __init__(self, input_size, hidden_size, latent_size):
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc21 = nn.Linear(hidden_size, latent_size)
        self.fc22 = nn.Linear(hidden_size, latent_size)
        
    def forward(self, x):
        hidden = F.relu(self.fc1(x))
        mu = self.fc21(hidden)
        logvar = self.fc22(hidden)
        return mu, logvar

class Decoder(nn.Module):
    def __init__(self, latent_size, hidden_size, output_size):
        super().__init__()
        self.fc1 = nn.Linear(latent_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, output_size)
        
    def forward(self, z):
        hidden = F.relu(self.fc1(z))
        output = torch.sigmoid(self.fc2(hidden))
        return output

def kl_loss(mu, logvar):
    return -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

def reconstruction_loss(x, x_origin):
    return F.mse_loss(x_origin, x, reduction='sum')

def loss_function(x, x_origin, mu, logvar):
    BCE = reconstruction_loss(x, x_origin)
    KLD = kl_loss(mu, logvar)
    return BCE + KLD, BCE, KLD

class VAE(nn.Module):
    def __init__(self, input_size, hidden_size, latent_size):
        super().__init__()
        self.encoder = Encoder(input_size, hidden_size, latent_size)
        self.decoder = Decoder(latent_size, hidden_size, input_size)
        
    def encode(self, x):
        return self.encoder(x)
    
    def decode(self, z):
        return self.decoder(z)
    
    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std
    
    def forward(self, x):
        mu, logvar = self.encode(x.view(-1, input_size))
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar

# train
input_size = 784
hidden_size = 256
latent_size = 10
epochs = 10
batch_size = 64
lr = 0.001

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=True, download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ])),
    batch_size=batch_size, shuffle=True)

model = VAE(input_size, hidden_size, latent_size).to(device)
optimizer = optim.Adam(model.parameters(), lr=lr)

for epoch in range(1, epochs+1):
    model.train()
    train_loss = 0
    train_BCE_loss = 0
    train_KLD_loss = 0
    for batch_idx, (data, _) in enumerate(train_loader):
        data = data.to(device)
        optimizer.zero_grad()
        recon_batch, mu, logvar = model(data)
        loss, BCE_loss, KLD_loss = loss_function(recon_batch, data, mu, logvar)
        loss.backward()
        train_loss += loss.item()
        train_BCE_loss += BCE_loss.item()
        train_KLD_loss += KLD_loss.item()
        optimizer.step()
        
    print('Epoch [{}/{}], Loss: {:.4f}, BCE Loss: {:.4f}, KL Divergence: {:.4f}'.format(
                epoch, epochs, train_loss / len(train_loader.dataset), 
                train_BCE_loss / len(train_loader.dataset),
                train_KLD_loss / len(train_loader.dataset)))