您的位置:

fit_generator详解

一、fit_generator简介

在深度学习中,当数据集很大时,不能将所有的数据存到内存中。在这种情况下,我们需要使用fit_generator来逐步加入数据进行训练,而不是一次性全部读入内存。

fit_generator是keras中用来进行数据迭代训练的函数,它是fit函数的增强版,用于处理大规模数据集,尤其是超出内存容量的数据集。可以从多种数据来源获取数据,包括numpy数组和目录中的图像文件。在每个epoch中,使用生成器生成mini-batch数据。

在使用fit_generator时,需要编写一个生成器函数,返回训练集中的一个batch,并指定每个epoch包含多少个batch。对于每个epoch,fit_generator函数会自动调用生成器函数,并使用生成的数据进行前向传递,反向传递,调整权重,继续下一个epoch的训练。


from keras.preprocessing.image import ImageDataGenerator

train_datagen = ImageDataGenerator(rescale=1./255)
train_generator = train_datagen.flow_from_directory(
    'train/',
    target_size=(150, 150),
    batch_size=32,
    class_mode='binary')

model.fit_generator(
    train_generator,
    steps_per_epoch=2000,
    epochs=50)

二、生成器函数的构造

生成器函数是fit_generator的关键组成部分。生成器函数必须返回每个batch,并且batch的大小必须与模型期望的输入尺寸一致。在数据生成过程中,可以应用各种数据增强技术,如旋转、缩放和平移等操作以及不同类型的变化,如对图像进行翻转、色彩抖动、裁剪、随机变换等等。

以图像分类为例,生成器函数通常会将图像处理成Numpy数组,并生成每个batch的输入和输出。下面是生成器函数的一个示例:


def image_generator(x_train, y_train, batch_size):
    while True:
        for i in range(0, len(x_train), batch_size):
            x = x_train[i:i+batch_size]/255.0
            y = y_train[i:i+batch_size]
            yield x, y

三、数据增强技术

在深度学习模型训练中,数据增强技术可以有效地增加数据样本量、提高模型的泛化性能。

Keras提供了ImageDataGenerator类,提供了多种数据增强的方式。

1. 对图像进行旋转、平移等操作


train_datagen = ImageDataGenerator(
    rotation_range=40,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    fill_mode='nearest')

train_generator = train_datagen.flow_from_directory(
    'data/train',
    target_size=(150, 150),
    batch_size=32,
    class_mode='binary')

2. 对图像进行色彩抖动


train_datagen = ImageDataGenerator(
    rescale=1./255,
    rotation_range=30,
    brightness_range=[0.5,1.5],
    shear_range=0.2,
    zoom_range=0.2,
    width_shift_range=0.3,
    height_shift_range=0.3,
    horizontal_flip=True,
    vertical_flip=False)

train_generator = train_datagen.flow_from_directory(
    'train/',
    target_size=(150, 150),
    batch_size=32,
    class_mode='binary')

四、小结

在大规模数据集中,fit_generator是进行深度学习模型训练的重要工具之一。它可以从多种数据来源获取数据,包括numpy数组和目录中的图像文件。在每个epoch中,使用生成器生成mini-batch数据,并自动生成数据来进行训练。

同时,加入数据增强技术可以有效地增加数据样本量,提高模型的泛化性能。