您的位置:

深度学习中的AI绘画技术——探究VAE模型

一、什么是VAE模型

VAE全称为Variational Autoencoder,是一种生成模型。VAE通过将输入数据映射到潜在空间中,实现对样本的压缩和重构,并且通过引入潜在变量来控制生成数据的分布,从而可以生成新的数据样本。

VAE模型的主要特点是使用了变分下界来优化模型,从而让模型在训练过程中更加稳定,同时可以利用VAE学到的潜在空间进行插值、生成多个样本等任务。

下面是使用PyTorch实现VAE模型的示例代码:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    
    class VAE(nn.Module):
        """Variational Autoencoder with fully connected encoder and decoder.

        The encoder maps a flattened input to the mean and log-variance of a
        diagonal Gaussian posterior q(z|x); a latent sample is drawn with the
        reparameterization trick and decoded back to input space through a
        sigmoid, so reconstructions lie in (0, 1).
        """

        def __init__(self, input_size, hidden_size, latent_size):
            super().__init__()
            self.input_size = input_size
            self.hidden_size = hidden_size
            self.latent_size = latent_size

            # Encoder: shared hidden layer, then separate heads for the
            # posterior mean and log-variance.
            self.enc_fc1 = nn.Linear(input_size, hidden_size)
            self.enc_fc2_mean = nn.Linear(hidden_size, latent_size)
            self.enc_fc2_logvar = nn.Linear(hidden_size, latent_size)

            # Decoder: mirror image of the encoder.
            self.dec_fc1 = nn.Linear(latent_size, hidden_size)
            self.dec_fc2 = nn.Linear(hidden_size, input_size)

        def encode(self, x):
            """Return (mean, logvar) of the approximate posterior q(z|x)."""
            hidden = F.relu(self.enc_fc1(x))
            return self.enc_fc2_mean(hidden), self.enc_fc2_logvar(hidden)

        def reparameterize(self, mean, logvar):
            """Sample z = mean + sigma * eps with eps ~ N(0, I)."""
            sigma = torch.exp(0.5 * logvar)
            noise = torch.randn_like(sigma)
            return mean + noise * sigma

        def decode(self, z):
            """Map a latent vector back to input space; values in (0, 1)."""
            hidden = F.relu(self.dec_fc1(z))
            return torch.sigmoid(self.dec_fc2(hidden))

        def forward(self, x):
            """Encode, sample a latent, decode; return (x_hat, mean, logvar)."""
            mean, logvar = self.encode(x)
            latent = self.reparameterize(mean, logvar)
            return self.decode(latent), mean, logvar

二、VAE模型在图像生成中的应用

VAE模型在图像生成中的应用是在潜在空间中生成新的样本。通常情况下,我们可以使用VAE将输入图片编码成一个低维的向量,然后在潜在空间中随机采样,最后将采样到的向量解码成新的图片。

下面是使用VAE模型在MNIST数据集上进行图片生成的示例代码:

    import torch
    import torch.nn.functional as F
    import torchvision
    import torchvision.transforms as transforms
    import matplotlib.pyplot as plt

    # Load MNIST as tensors in [0, 1].
    # NOTE: no Normalize((0.5,), (0.5,)) here — F.binary_cross_entropy
    # requires its targets to lie in [0, 1]; normalizing the images to
    # [-1, 1] would make the BCE reconstruction loss invalid.
    transform = transforms.ToTensor()

    trainset = torchvision.datasets.MNIST(root='./data', train=True,
                                            download=True, transform=transform)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=128,
                                              shuffle=True, num_workers=2)


    # Train the model on flattened 28x28 images (784 inputs).
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = VAE(input_size=784, hidden_size=512, latent_size=20).to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    def loss_function(x_hat, x, mean, logvar):
        """ELBO loss: reconstruction BCE plus KL divergence to N(0, I).

        Both terms are summed (not averaged) over the batch, matching the
        standard VAE objective; the per-sample average is printed below.
        """
        BCE = F.binary_cross_entropy(x_hat, x, reduction='sum')
        KLD = -0.5 * torch.sum(1 + logvar - mean.pow(2) - logvar.exp())
        return BCE + KLD

    num_epochs = 20

    for epoch in range(num_epochs):
        running_loss = 0.0
        for i, data in enumerate(trainloader, 0):
            inputs, _ = data
            # Flatten each image to a 784-dim vector for the MLP encoder.
            inputs = inputs.view(inputs.size(0), -1).to(device)
            optimizer.zero_grad()
            x_hat, mean, logvar = model(inputs)
            loss = loss_function(x_hat, inputs, mean, logvar)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        print('[%d] loss: %.3f' %
              (epoch + 1, running_loss / len(trainset)))

    # Generate new images by decoding latent vectors drawn from the prior.
    with torch.no_grad():
        z = torch.randn(10, 20).to(device)
        samples = model.decode(z).cpu()

    fig, axs = plt.subplots(1, 10, figsize=(20, 2))
    for i in range(10):
        axs[i].imshow(samples[i].view(28, 28), cmap='gray')
        axs[i].axis('off')

    plt.show()

三、VAE模型在图像修复中的应用

VAE模型在图像修复中的应用是利用VAE学习到的潜在空间对图片进行修复。可以将待修复图片编码成潜在空间中的向量,对缺失的部分进行插值,然后解码成新的图片。

下面是使用VAE模型在CelebA数据集上进行图像修复的示例代码:

    import torch
    import torch.nn.functional as F
    import torchvision
    import torchvision.transforms as transforms
    import matplotlib.pyplot as plt
    from PIL import Image

    # Load CelebA crops as tensors in [0, 1].
    # NOTE: no Normalize here — F.binary_cross_entropy requires targets in
    # [0, 1]; normalizing to [-1, 1] would make the BCE loss invalid.
    transform = transforms.Compose([
            transforms.CenterCrop((128, 128)),
            transforms.ToTensor(),
        ])

    dataset = torchvision.datasets.ImageFolder(root='./celeba_train', transform=transform)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=24, shuffle=True, num_workers=2)

    # Train the model on flattened 3x128x128 images.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = VAE(input_size=3*128*128, hidden_size=1024, latent_size=512).to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    def loss_function(x_hat, x, mean, logvar):
        """ELBO loss: reconstruction BCE plus KL divergence to N(0, I)."""
        BCE = F.binary_cross_entropy(x_hat, x, reduction='sum')
        KLD = -0.5 * torch.sum(1 + logvar - mean.pow(2) - logvar.exp())
        return BCE + KLD

    num_epochs = 20

    for epoch in range(num_epochs):
        running_loss = 0.0
        for i, data in enumerate(dataloader, 0):
            inputs, _ = data
            inputs = inputs.view(inputs.size(0), -1).to(device)
            optimizer.zero_grad()
            x_hat, mean, logvar = model(inputs)
            loss = loss_function(x_hat, inputs, mean, logvar)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        print('[%d] loss: %.3f' %
              (epoch + 1, running_loss / len(dataset)))

    # Image inpainting: keep only a small patch of the test image, encode
    # the masked image, and decode to reconstruct a full face.
    test_img_path = './test.jpg'

    # convert('RGB') guards against grayscale/RGBA inputs breaking the
    # 3-channel reshape below.
    img = Image.open(test_img_path).convert('RGB')
    img = transform(img).unsqueeze(0).to(device)

    # Mask: zero everywhere except the retained patch (dims are N, C, H, W).
    img_label = torch.zeros_like(img)
    img_label[:, :, 50:78, 60:88] = img[:, :, 50:78, 60:88]

    with torch.no_grad():
        # encode() returns (mean, logvar); use the posterior mean as the
        # latent code. (The original `z, _, _ = model.encode(...)` tried to
        # unpack three values from a two-tuple and raised a ValueError.)
        mean, logvar = model.encode(img_label.view(1, -1))
        z = mean
        z[:, 256:] = 0  # set the second half of z to 0

        # decode() returns a flat (1, 3*128*128) tensor; reshape to image
        # layout so the channel transpose below is valid.
        fixed_img = model.decode(z).view(1, 3, 128, 128)

    img = img.cpu().squeeze().numpy().transpose(1, 2, 0)
    img_label = img_label.cpu().squeeze().numpy().transpose(1, 2, 0)
    fixed_img = fixed_img.cpu().squeeze().numpy().transpose(1, 2, 0)

    fig, axs = plt.subplots(1, 3, figsize=(15, 5))

    axs[0].imshow(img)
    axs[0].set_title('original image')
    axs[0].axis('off')

    axs[1].imshow(img_label)
    axs[1].set_title('image with mask')
    axs[1].axis('off')

    axs[2].imshow(fixed_img)
    axs[2].set_title('fixed image')
    axs[2].axis('off')

    plt.show()