您的位置:

手把手教你使用Python读取MNIST数据集

一、MNIST数据集介绍

MNIST数据集是深度学习和机器学习领域中非常著名的一个数据集,它包含了大量的手写数字图像,被广泛地用于各种分类算法的评测和比较。MNIST数据集一共包含70000张手写数字的图片,每张图片的大小为28 x 28像素。其中,前60000张图片被用作训练集,后10000张图片被用作测试集。训练集和测试集中的标签均为0至9的数字。

二、MNIST数据集的下载

要想使用Python读取MNIST数据集,首先需要下载数据集。在下面的代码片段中,我们将展示如何使用Python代码下载MNIST数据集:

from tensorflow.examples.tutorials.mnist import input_data
 
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)

以上代码中,我们使用了TensorFlow自带的input_data脚本来下载MNIST数据集。该脚本将训练集、验证集、测试集分别存储到MNIST_data文件夹下,并将标签转换为one-hot编码形式,以便于模型训练时使用。

三、读取MNIST数据集

有了MNIST数据集的下载,接下来我们就可以开始读取数据了。在下面的示例代码中,我们将展示如何使用Python代码读取MNIST数据集的训练集和测试集,以及如何显示其中一张图片:

import numpy as np
import matplotlib.pyplot as plt
 
# 载入数据集
train = mnist.train.images
train_labels = mnist.train.labels
test = mnist.test.images
test_labels = mnist.test.labels
 
# 显示数据集中的一张图片
plt.imshow(np.reshape(train[0], [28, 28]), cmap='gray')
plt.show()

以上代码中,我们使用了NumPy和matplotlib两个库,分别用于将图像数据转换为矩阵,以及将矩阵显示为图像。其中,train和test分别为训练集和测试集的图像数据,train_labels和test_labels分别为训练集和测试集的标签。我们使用plt.imshow函数来显示数据集中的一张图片,该函数中的cmap参数用于指定灰度图像的颜色映射方式。

四、MNIST数据集的预处理

在进行模型训练之前,我们需要对MNIST数据集进行一些预处理操作。在本小节中,我们将介绍两个常用的预处理方法,分别是数据归一化和数据增强。

1. 数据归一化

数据归一化是指将数据的取值范围缩放到一定区间内的操作。在深度学习和机器学习中,数据归一化是一项非常重要的预处理方法,常用的数据归一化方式有两种:最小-最大规范化和标准化。

最小-最大规范化将数据缩放到[0,1]区间内,其公式如下:

\begin{aligned} {x}'=\frac{x-min}{max-min} \end{aligned}

其中,min是数据集中的最小值,max是数据集中的最大值。在下面的代码中,我们将展示如何使用最小-最大规范化对MNIST数据集进行归一化处理:

# 最小-最大规范化
def min_max_normalize(data):
    min_val = np.min(data)
    max_val = np.max(data)
    return (data - min_val) / (max_val - min_val)
 
# 对数据集进行归一化
train = min_max_normalize(train)
test = min_max_normalize(test)

以上代码中,我们定义了一个min_max_normalize函数,用于对输入数据进行最小-最大规范化处理。在该函数中,我们首先计算出数据集中的最小值和最大值,然后将其应用到数据集中,以得到归一化后的数据。

2. 数据增强

数据增强是指利用一些变换操作扩充原始数据集的操作,旨在提高模型的泛化能力。在数字识别任务中,常用的数据增强方式有平移、旋转、缩放等。

对于MNIST数据集,我们可以利用一些平移和旋转操作生成新的图像。在下面的代码中,我们将展示如何使用Python代码进行平移操作和旋转操作:

from scipy.ndimage.interpolation import shift, rotate
 
# 平移变换
def shift_image(image, sx, sy):
    return shift(image.reshape((28, 28)), [sy, sx], cval=0.).reshape([-1])
 
# 旋转变换
def rotate_image(image, angle):
    return rotate(image.reshape((28, 28)), angle, cval=0., reshape=False).reshape([-1])
 
# 生成平移后的图像
shifted_imgs = [shift_image(train[i], 5, 5) for i in range(train.shape[0])]
 
# 生成旋转后的图像
rotated_imgs = [rotate_image(train[i], 30) for i in range(train.shape[0])]

以上代码中,我们使用了SciPy中的ndimage模块中的shift和rotate函数,分别用于进行平移变换和旋转变换。在shift_image函数中,我们通过调用ndimage.shift函数来实现平移变化,其中sx和sy分别指定x轴和y轴的平移距离。在rotate_image函数中,我们使用ndimage.rotate函数来进行旋转变化,其中angle指定旋转角度。

五、小结

本文主要介绍了如何使用Python读取MNIST数据集,并对数据集进行了预处理操作。虽然MNIST数据集已经是一个非常老旧的数据集了,但它仍广泛地应用于各种深度学习和机器学习算法的评测和比较中。在使用MNIST数据集进行数字识别任务时,我们需要注意数据的归一化和增强等预处理操作,以提高模型的准确率和泛化能力。