一、GAN概述
GAN(Generative Adversarial Networks)是一种生成模型,由生成器和判别器两大部分组成,目的是学习真实数据分布,并且从噪声中生成与真实数据相似的样本。
二、GAN的损失函数
GAN的损失函数包含两个部分:生成器损失和判别器损失。
1. 生成器损失
生成器的任务是生成与真实数据相似的数据样本,因此其损失可以定义为判别器无法正确辨别生成数据的概率,即:
G_loss = -log(D(G(z)))
其中,G(z)表示噪声z经过生成器G生成的样本,D表示判别器,G_loss越小,则生成样本越接近真实数据。
2. 判别器损失
判别器的任务是辨别真实数据和生成数据,因此其损失可以定义为正确分类真实样本的概率和正确分类生成样本的概率的平均数,即:
D_loss = -log(D(x)) -log(1-D(G(z)))
其中,x表示真实数据,D(x)表示判别器将真实数据判为真实数据的概率,D(G(z))表示判别器将生成数据判为真实数据的概率,D_loss越小,则判别器越能够准确地分辨真实数据和生成数据。
三、GAN损失函数的训练过程
GAN的训练过程是博弈过程,即生成器和判别器不断地相互博弈,训练流程如下:
1. 初始化参数
生成器和判别器都需要初始化参数,生成器的参数通常以随机噪声z作为输入,输出与真实数据相似的样本数据,判别器的参数通常以真实数据或者生成器生成的数据作为输入,输出为0(真实数据)或1(生成数据)。
2. 训练判别器
首先固定生成器的参数,训练判别器的参数,让判别器能够准确地分辨真实数据和生成数据,即最小化判别器损失函数:
min(D_loss)
3. 训练生成器
接着固定判别器参数,训练生成器的参数,让生成器生成与真实数据相似的样本数据,即最小化生成器损失函数:
min(G_loss)
4. 不断交替训练
在训练过程中,生成器和判别器不断交替训练,直到生成器生成的样本无法被判别器辨别为止。
四、完整代码示例
import tensorflow as tf from tensorflow import keras import numpy as np # 定义生成器网络 def make_generator_model(): model = keras.Sequential() model.add(keras.layers.Dense(256, input_shape=(100,), use_bias=False)) model.add(keras.layers.BatchNormalization()) model.add(keras.layers.LeakyReLU()) model.add(keras.layers.Dense(512, use_bias=False)) model.add(keras.layers.BatchNormalization()) model.add(keras.layers.LeakyReLU()) model.add(keras.layers.Dense(28*28*1, use_bias=False, activation='tanh')) model.add(keras.layers.Reshape((28, 28, 1))) return model # 定义判别器网络 def make_discriminator_model(): model = keras.Sequential() model.add(keras.layers.Flatten(input_shape=(28,28,1))) model.add(keras.layers.Dense(512)) model.add(keras.layers.LeakyReLU()) model.add(keras.layers.Dense(256)) model.add(keras.layers.LeakyReLU()) model.add(keras.layers.Dense(1, activation='sigmoid')) return model # 定义损失函数 cross_entropy = keras.losses.BinaryCrossentropy(from_logits=True) # 判别器损失函数 def discriminator_loss(real_output, fake_output): real_loss = cross_entropy(tf.ones_like(real_output), real_output) fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output) total_loss = real_loss + fake_loss return total_loss # 生成器损失函数 def generator_loss(fake_output): return cross_entropy(tf.ones_like(fake_output), fake_output) # 定义优化器 generator_optimizer = keras.optimizers.Adam(1e-4) discriminator_optimizer = keras.optimizers.Adam(1e-4) # 定义训练步骤 @tf.function def train_step(images): noise = tf.random.normal([BATCH_SIZE, 100]) with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape: generated_images = generator(noise, training=True) real_output = discriminator(images, training=True) fake_output = discriminator(generated_images, training=True) gen_loss = generator_loss(fake_output) disc_loss = discriminator_loss(real_output, fake_output) gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables) gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables) generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables)) discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables)) # 定义训练过程 def train(dataset, epochs): for epoch in range(epochs): for image_batch in dataset: train_step(image_batch) # 加载数据集 (train_images, train_labels), (_, _) = keras.datasets.mnist.load_data() train_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype('float32') train_images = (train_images - 127.5) / 127.5 # 将像素范围缩放到[-1, 1]之间 BUFFER_SIZE = 60000 BATCH_SIZE = 256 train_dataset = tf.data.Dataset.from_tensor_slices(train_images).shuffle(BUFFER_SIZE).batch(BATCH_SIZE) # 构造生成器和判别器 generator = make_generator_model() discriminator = make_discriminator_model() # 定义训练参数 EPOCHS = 100 noise_dim = 100 num_examples_to_generate = 16 # 每轮生成的样本数量 # 开始训练 train(train_dataset, EPOCHS) # 生成样本 noise = tf.random.normal([num_examples_to_generate, noise_dim]) generated_images = generator(noise, training=False) # 展示生成的样本 import matplotlib.pyplot as plt fig = plt.figure(figsize=(4,4)) for i in range(generated_images.shape[0]): plt.subplot(4, 4, i+1) plt.imshow((generated_images[i, :, :, 0] + 1)/2, cmap='gray') plt.axis('off') plt.show()