一、fit_generator简介
在深度学习中,当数据集很大时,不能将所有的数据存到内存中。在这种情况下,我们需要使用fit_generator来逐步加入数据进行训练,而不是一次性全部读入内存。
fit_generator是keras中用来进行数据迭代训练的函数,它是fit函数的增强版,用于处理大规模数据集,尤其是超出内存容量的数据集。可以从多种数据来源获取数据,包括numpy数组和目录中的图像文件。在每个epoch中,使用生成器生成mini-batch数据。
在使用fit_generator时,需要编写一个生成器函数,返回训练集中的一个batch,并指定每个epoch包含多少个batch。对于每个epoch,fit_generator函数会自动调用生成器函数,并使用生成的数据进行前向传递,反向传递,调整权重,继续下一个epoch的训练。
from keras.preprocessing.image import ImageDataGenerator
train_datagen = ImageDataGenerator(rescale=1./255)
train_generator = train_datagen.flow_from_directory(
'train/',
target_size=(150, 150),
batch_size=32,
class_mode='binary')
model.fit_generator(
train_generator,
steps_per_epoch=2000,
epochs=50)
二、生成器函数的构造
生成器函数是fit_generator的关键组成部分。生成器函数必须返回每个batch,并且batch的大小必须与模型期望的输入尺寸一致。在数据生成过程中,可以应用各种数据增强技术,如旋转、缩放和平移等操作以及不同类型的变化,如对图像进行翻转、色彩抖动、裁剪、随机变换等等。
以图像分类为例,生成器函数通常会将图像处理成Numpy数组,并生成每个batch的输入和输出。下面是生成器函数的一个示例:
def image_generator(x_train, y_train, batch_size):
while True:
for i in range(0, len(x_train), batch_size):
x = x_train[i:i+batch_size]/255.0
y = y_train[i:i+batch_size]
yield x, y
三、数据增强技术
在深度学习模型训练中,数据增强技术可以有效地增加数据样本量、提高模型的泛化性能。
Keras提供了ImageDataGenerator类,提供了多种数据增强的方式。
1. 对图像进行旋转、平移等操作
train_datagen = ImageDataGenerator(
rotation_range=40,
width_shift_range=0.2,
height_shift_range=0.2,
shear_range=0.2,
zoom_range=0.2,
horizontal_flip=True,
fill_mode='nearest')
train_generator = train_datagen.flow_from_directory(
'data/train',
target_size=(150, 150),
batch_size=32,
class_mode='binary')
2. 对图像进行色彩抖动
train_datagen = ImageDataGenerator(
rescale=1./255,
rotation_range=30,
brightness_range=[0.5,1.5],
shear_range=0.2,
zoom_range=0.2,
width_shift_range=0.3,
height_shift_range=0.3,
horizontal_flip=True,
vertical_flip=False)
train_generator = train_datagen.flow_from_directory(
'train/',
target_size=(150, 150),
batch_size=32,
class_mode='binary')
四、小结
在大规模数据集中,fit_generator是进行深度学习模型训练的重要工具之一。它可以从多种数据来源获取数据,包括numpy数组和目录中的图像文件。在每个epoch中,使用生成器生成mini-batch数据,并自动生成数据来进行训练。
同时,加入数据增强技术可以有效地增加数据样本量,提高模型的泛化性能。