1. What Is a VAE Model
VAE stands for Variational Autoencoder, a type of generative model. A VAE maps input data into a latent space, compressing each sample and reconstructing it, and by placing a prior distribution on the latent variables it controls the distribution of what the decoder produces, which makes it possible to generate entirely new samples.
The defining feature of the VAE is that it is trained by optimizing a variational lower bound on the data likelihood, which keeps training stable; the learned latent space can then be used for tasks such as interpolating between samples or generating many new ones.
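Concretely, training maximizes the evidence lower bound (ELBO), or equivalently minimizes its negative. With a Gaussian encoder q(z|x) = N(mean, diag(exp(logvar))) and a standard normal prior p(z), the KL term has a closed form:

    ELBO = E_{q(z|x)}[ log p(x|z) ] - KL( q(z|x) || p(z) )
    KL( q(z|x) || p(z) ) = -1/2 * sum( 1 + logvar - mean^2 - exp(logvar) )

The loss_function in the training code further below is exactly the negative of this bound, with the reconstruction term implemented as a binary cross-entropy.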
Below is example code implementing a VAE in PyTorch:
import torch
import torch.nn as nn
import torch.nn.functional as F

class VAE(nn.Module):
    def __init__(self, input_size, hidden_size, latent_size):
        super(VAE, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.latent_size = latent_size
        # encoder
        self.enc_fc1 = nn.Linear(input_size, hidden_size)
        self.enc_fc2_mean = nn.Linear(hidden_size, latent_size)
        self.enc_fc2_logvar = nn.Linear(hidden_size, latent_size)
        # decoder
        self.dec_fc1 = nn.Linear(latent_size, hidden_size)
        self.dec_fc2 = nn.Linear(hidden_size, input_size)

    def encode(self, x):
        # map the input to the mean and log-variance of q(z|x)
        h = F.relu(self.enc_fc1(x))
        mean = self.enc_fc2_mean(h)
        logvar = self.enc_fc2_logvar(h)
        return mean, logvar

    def reparameterize(self, mean, logvar):
        # sample z = mean + std * eps so gradients can flow through mean and logvar
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mean + eps * std

    def decode(self, z):
        h = F.relu(self.dec_fc1(z))
        x_hat = torch.sigmoid(self.dec_fc2(h))  # outputs lie in [0, 1]
        return x_hat

    def forward(self, x):
        mean, logvar = self.encode(x)
        z = self.reparameterize(mean, logvar)
        x_hat = self.decode(z)
        return x_hat, mean, logvar
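Before training, a quick sanity check with a random batch confirms that the tensor shapes line up. The sizes here (a batch of 4 flattened 28x28 images, a 20-dimensional latent) are illustrative choices, not requirements:

# smoke test: run a dummy batch through the model and check shapes
model = VAE(input_size=784, hidden_size=512, latent_size=20)
x = torch.rand(4, 784)  # fake pixels in [0, 1]
x_hat, mean, logvar = model(x)
print(x_hat.shape)   # torch.Size([4, 784])
print(mean.shape)    # torch.Size([4, 20])
print(logvar.shape)  # torch.Size([4, 20])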
2. Applying the VAE Model to Image Generation
For image generation, the VAE produces new samples directly from the latent space. The usual workflow is to train an encoder that maps input images to low-dimensional vectors, then draw random samples in the latent space and decode each sampled vector into a new image.
Below is example code that trains the VAE on the MNIST dataset and generates new images:
import torch
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt

# load data; keep pixels in [0, 1] so they match the sigmoid
# decoder output and the binary cross-entropy loss
transform = transforms.ToTensor()
trainset = torchvision.datasets.MNIST(root='./data', train=True,
                                      download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128,
                                          shuffle=True, num_workers=2)

# train model
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = VAE(input_size=784, hidden_size=512, latent_size=20).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

def loss_function(x_hat, x, mean, logvar):
    # reconstruction term plus the closed-form Gaussian KL divergence
    BCE = F.binary_cross_entropy(x_hat, x, reduction='sum')
    KLD = -0.5 * torch.sum(1 + logvar - mean.pow(2) - logvar.exp())
    return BCE + KLD

num_epochs = 20
for epoch in range(num_epochs):
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        inputs, _ = data
        inputs = inputs.view(inputs.size(0), -1).to(device)
        optimizer.zero_grad()
        x_hat, mean, logvar = model(inputs)
        loss = loss_function(x_hat, inputs, mean, logvar)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    print('[%d] loss: %.3f' % (epoch + 1, running_loss / len(trainset)))

# generate new images by decoding random latent vectors
with torch.no_grad():
    z = torch.randn(10, 20).to(device)
    samples = model.decode(z).cpu()

fig, axs = plt.subplots(1, 10, figsize=(20, 2))
for i in range(10):
    axs[i].imshow(samples[i].view(28, 28), cmap='gray')
    axs[i].axis('off')
plt.show()
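Because nearby points in the latent space decode to similar images, the same trained model can also interpolate between two digits, as mentioned in the first section. The sketch below (assuming the model and trainloader defined above) linearly blends the latent means of two training images and decodes ten intermediate codes:

# interpolate between the latent codes of two real MNIST digits
with torch.no_grad():
    batch, _ = next(iter(trainloader))
    x = batch.view(batch.size(0), -1).to(device)
    mean, _ = model.encode(x[:2])                # codes for two images
    steps = torch.linspace(0, 1, 10, device=device).unsqueeze(1)
    z = (1 - steps) * mean[0] + steps * mean[1]  # 10 blended codes
    interp = model.decode(z).cpu()

fig, axs = plt.subplots(1, 10, figsize=(20, 2))
for i in range(10):
    axs[i].imshow(interp[i].view(28, 28), cmap='gray')
    axs[i].axis('off')
plt.show()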
3. Applying the VAE Model to Image Inpainting
For image inpainting, the latent space learned by the VAE is used to repair a damaged image: encode the corrupted image into a latent vector, fill in the information for the missing region, and decode the result into a repaired image.
Below is example code that uses the VAE for image inpainting on the CelebA dataset:
import torch
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from PIL import Image

# load data; keep pixels in [0, 1] so they match the sigmoid
# decoder output and the binary cross-entropy loss
transform = transforms.Compose([
    transforms.CenterCrop((128, 128)),
    transforms.ToTensor(),
])
dataset = torchvision.datasets.ImageFolder(root='./celeba_train',
                                           transform=transform)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=24,
                                         shuffle=True, num_workers=2)

# train model
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = VAE(input_size=3*128*128, hidden_size=1024, latent_size=512).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

def loss_function(x_hat, x, mean, logvar):
    BCE = F.binary_cross_entropy(x_hat, x, reduction='sum')
    KLD = -0.5 * torch.sum(1 + logvar - mean.pow(2) - logvar.exp())
    return BCE + KLD

num_epochs = 20
for epoch in range(num_epochs):
    running_loss = 0.0
    for i, data in enumerate(dataloader, 0):
        inputs, _ = data
        inputs = inputs.view(inputs.size(0), -1).to(device)
        optimizer.zero_grad()
        x_hat, mean, logvar = model(inputs)
        loss = loss_function(x_hat, inputs, mean, logvar)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    print('[%d] loss: %.3f' % (epoch + 1, running_loss / len(dataset)))

# image inpainting: blank out a rectangular region, encode the
# corrupted image, and decode to obtain a repaired image
test_img_path = './test.jpg'
img = transform(Image.open(test_img_path).convert('RGB')).unsqueeze(0).to(device)
img_label = img.clone()
img_label[:, :, 50:78, 60:88] = 0  # mask out a 28x28 patch

with torch.no_grad():
    mean, logvar = model.encode(img_label.view(1, -1))
    z = mean       # use the latent mean as the code
    z[:, 256:] = 0 # heuristic: zero out the second half of the latent code
    fixed_img = model.decode(z).view(3, 128, 128)

img = img.cpu().squeeze().numpy().transpose(1, 2, 0)
img_label = img_label.cpu().squeeze().numpy().transpose(1, 2, 0)
fixed_img = fixed_img.cpu().numpy().transpose(1, 2, 0)

fig, axs = plt.subplots(1, 3, figsize=(15, 5))
axs[0].imshow(img)
axs[0].set_title('original image')
axs[0].axis('off')
axs[1].imshow(img_label)
axs[1].set_title('image with mask')
axs[1].axis('off')
axs[2].imshow(fixed_img)
axs[2].set_title('fixed image')
axs[2].axis('off')
plt.show()
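One caveat: the decoder redraws every pixel, so regions that were never masked also change. A common refinement, shown as a minimal sketch below using the numpy arrays img and fixed_img from the code above, is to paste the reconstruction back only inside the masked rectangle and keep the known pixels everywhere else:

# composite the repair: keep the known pixels from the original and
# take the VAE reconstruction only inside the masked 28x28 region
composite = img.copy()
composite[50:78, 60:88, :] = fixed_img[50:78, 60:88, :]

plt.imshow(composite)
plt.title('composited repair')
plt.axis('off')
plt.show()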