Downloading and Introducing the CIFAR-100 Dataset

1. Downloading the CIFAR-10 Dataset

import urllib.request
import tarfile
import os

url = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"
filepath = "cifar-10-python.tar.gz"

# Download the archive only if it is not already present.
if not os.path.isfile(filepath):
    result = urllib.request.urlretrieve(url, filepath)
    print('downloaded:', result)

# Extract the archive only if the target directory does not exist yet.
if not os.path.exists("cifar-10-batches-py"):
    with tarfile.open(filepath, 'r:gz') as tfile:
        tfile.extractall('.')  # extractall() returns None, so it is not printed
    print('extracted to: cifar-10-batches-py')
else:
    print('Data already exists.')

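CIFAR-10 (named after the Canadian Institute For Advanced Research) is a classic image classification dataset with 10 classes. Each class has 6,000 color images of 32×32 pixels, for 60,000 images in total: 50,000 form the training set and 10,000 form the test set.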

2. Reading the CIFAR-10 Dataset

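import os
import pickle
import numpy as np

def load_CIFAR_batch(filename):
    """Load one CIFAR-10 batch file and return (images, labels)."""
    with open(filename, 'rb') as f:
        # The batches were pickled under Python 2; 'latin1' lets Python 3 read them.
        datadict = pickle.load(f, encoding='latin1')
        X = datadict['data']
        Y = datadict['labels']
        # (10000, 3072) -> (10000, 3, 32, 32) -> channel-last (10000, 32, 32, 3)
        X = X.reshape(10000, 3, 32, 32).transpose(0,2,3,1).astype("float")
        Y = np.array(Y)
        return X, Y

def load_CIFAR10(ROOT):
    """Load all five training batches and the test batch from directory ROOT."""
    xs = []
    ys = []
    for b in range(1,6):
        f = os.path.join(ROOT, 'data_batch_%d' % (b,))
        X, Y = load_CIFAR_batch(f)
        xs.append(X)
        ys.append(Y)
    # Stack the five batches into one (50000, 32, 32, 3) training array.
    Xtr = np.concatenate(xs)
    Ytr = np.concatenate(ys)
    del X, Y
    Xte, Yte = load_CIFAR_batch(os.path.join(ROOT, 'test_batch'))
    return Xtr, Ytr, Xte, Yte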

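The CIFAR-10 training data is split into 5 batches of 10,000 images each, and the test data is a single batch of 10,000 images. The code above reads the dataset and converts it into NumPy arrays that are easy to work with.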
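The reshape-and-transpose step in the loader can be checked without downloading anything. The following is a minimal sketch, assuming a tiny synthetic batch written in the same dictionary layout; `load_batch_any_size` is a hypothetical variant of `load_CIFAR_batch` that infers the image count instead of hard-coding 10000, and the file name `data_batch_demo` is made up for the demo.

```python
import os
import pickle
import tempfile

import numpy as np


def load_batch_any_size(filename):
    """Hypothetical variant of load_CIFAR_batch that infers the image count."""
    with open(filename, 'rb') as f:
        datadict = pickle.load(f, encoding='latin1')
    # -1 lets NumPy infer the number of images from the number of rows.
    X = datadict['data'].reshape(-1, 3, 32, 32).transpose(0, 2, 3, 1).astype("float")
    Y = np.array(datadict['labels'])
    return X, Y


# Synthetic stand-in for a batch: 4 random "images" and 4 made-up labels.
rng = np.random.default_rng(0)
fake_batch = {
    'data': rng.integers(0, 256, size=(4, 3072), dtype=np.uint8),
    'labels': [0, 3, 9, 1],
}

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, 'data_batch_demo')
    with open(path, 'wb') as f:
        pickle.dump(fake_batch, f)
    X_demo, Y_demo = load_batch_any_size(path)

print(X_demo.shape, Y_demo.shape)
```

The result has one 32×32×3 array per image, which is the channel-last layout that `plt.imshow` expects.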

3. Introducing the CIFAR-10 Dataset

import numpy as np
import matplotlib.pyplot as plt

def visualize_CIFAR10_data(X_train, y_train, classes, samples_per_class=7):
    """Show a grid of random samples: one column per class, one row per sample."""
    num_classes = len(classes)
    for y, cls in enumerate(classes):
        # Pick samples_per_class random images whose label is y.
        idxes = np.flatnonzero(y_train == y)
        idxes = np.random.choice(idxes, samples_per_class, replace=False)
        for i, idx in enumerate(idxes):
            plt_idx = i * num_classes + y + 1
            plt.subplot(samples_per_class, num_classes, plt_idx)
            plt.imshow(X_train[idx].astype('uint8'))
            plt.axis('off')
            if i == 0:
                plt.title(cls)
    plt.show()

def get_CIFAR10_data(num_training=49000, num_validation=1000, num_test=10000):
    """Load CIFAR-10 (via load_CIFAR10 from section 2), split it, and center it."""
    cifar10_dir = '../datasets/cifar-10-batches-py'
    X_train, y_train, X_test, y_test = load_CIFAR10(cifar10_dir)

    # Hold out the images right after the first num_training ones for validation.
    mask = range(num_training, num_training + num_validation)
    X_val = X_train[mask]
    y_val = y_train[mask]
    mask = range(num_training)
    X_train = X_train[mask]
    y_train = y_train[mask]
    mask = range(num_test)
    X_test = X_test[mask]
    y_test = y_test[mask]

    # Subtract the mean image, computed on the training split only.
    mean_image = np.mean(X_train, axis=0)
    X_train -= mean_image
    X_val -= mean_image
    X_test -= mean_image

    return X_train, y_train, X_val, y_val, X_test, y_test

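CIFAR-10 was collected by Alex Krizhevsky, Vinod Nair, and Geoffrey Hinton; frameworks such as TensorFlow provide convenience loaders for it. The dataset ships with a training set and a test set only, so the code above holds out 1,000 training images as a validation set. Every image has one label drawn from 10 classes. The code above visualizes the dataset and prepares the training, validation, and test splits.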
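The split and mean-subtraction steps can be illustrated on synthetic data. This is a minimal sketch, assuming 100 random arrays in place of real images and made-up split sizes (90 training, 10 validation); it is not the tutorial's own function, only the same idea at small scale.

```python
import numpy as np

rng = np.random.default_rng(0)

# Synthetic stand-in for image data: 100 "images" of shape 32x32x3.
X = rng.integers(0, 256, size=(100, 32, 32, 3)).astype("float")
y = rng.integers(0, 10, size=100)

num_training, num_validation = 90, 10  # made-up sizes for the demo

# The validation split is the slice right after the training split.
X_val = X[num_training:num_training + num_validation]
y_val = y[num_training:num_training + num_validation]
X_train = X[:num_training]
y_train = y[:num_training]

# The mean image comes from the training split only and is then
# subtracted from every split, so no validation data leaks into it.
mean_image = X_train.mean(axis=0)
X_train = X_train - mean_image
X_val = X_val - mean_image

print(mean_image.shape)
print(np.allclose(X_train.mean(axis=0), 0.0))
```

After centering, the per-pixel mean of the training split is zero up to floating-point error, while the validation split is only approximately centered because it uses the training mean.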

4. Downloading the CIFAR-100 Dataset

import urllib.request
import tarfile
import os

url = "https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz"
filepath = "cifar-100-python.tar.gz"

# Download the archive only if it is not already present.
if not os.path.isfile(filepath):
    result = urllib.request.urlretrieve(url, filepath)
    print('downloaded:', result)

# Extract the archive only if the target directory does not exist yet.
if not os.path.exists("cifar-100-python"):
    with tarfile.open(filepath, 'r:gz') as tfile:
        tfile.extractall('.')  # extractall() returns None, so it is not printed
    print('extracted to: cifar-100-python')
else:
    print('Data already exists.')

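CIFAR-100 has 100 classes with 600 color images of 32×32 pixels each; 50,000 images form the training set and 10,000 form the test set. It is organized much like CIFAR-10 and contains the same total number of images, but it has ten times as many classes and therefore far fewer images per class. Running the code above downloads the CIFAR-100 dataset.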
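The counts above can be cross-checked with a little arithmetic. This small sketch assumes the train/test split is balanced across classes, which is how the official dataset description presents it.

```python
# Totals stated in the text above.
num_classes = 100
images_per_class = 600
num_train, num_test = 50000, 10000

total_images = num_classes * images_per_class
assert total_images == num_train + num_test  # 60000 images overall

# Assuming a balanced split, each class contributes equally to both sets.
train_per_class = num_train // num_classes
test_per_class = num_test // num_classes

print(total_images, train_per_class, test_per_class)  # 60000 500 100
```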

5. Introducing the CIFAR-100 Dataset

import os
import pickle
import numpy as np
import matplotlib.pyplot as plt

def load_CIFAR100(filename):
    """Load a CIFAR-100 file ('train' or 'test') and return (images, fine labels)."""
    with open(filename, 'rb') as f:
        datadict = pickle.load(f, encoding='latin1')
        X = datadict['data']
        Y = datadict['fine_labels']
        # -1 infers the image count: 50000 for 'train', 10000 for 'test'.
        X = X.reshape(-1, 3, 32, 32).transpose(0,2,3,1).astype("float")
        Y = np.array(Y)
        return X, Y

def visualize_CIFAR100_data(X_train, y_train, classes, samples_per_class=7):
    """Show random samples per class; classes[i] must be the name of label i.

    Passing a short prefix such as classes[:10] keeps the figure readable.
    """
    num_classes = len(classes)
    for y, cls in enumerate(classes):
        idxes = np.flatnonzero(y_train == y)
        idxes = np.random.choice(idxes, samples_per_class, replace=False)
        for i, idx in enumerate(idxes):
            plt_idx = i * num_classes + y + 1
            plt.subplot(samples_per_class, num_classes, plt_idx)
            plt.imshow(X_train[idx].astype('uint8'))
            plt.axis('off')
            if i == 0:
                plt.title(cls)
    plt.show()

def get_CIFAR100_data(num_training=49000, num_validation=1000, num_test=10000):
    """Load CIFAR-100, split off a validation set, and subtract the mean image."""
    cifar100_dir = '../datasets/cifar-100-python'
    X_train, y_train = load_CIFAR100(os.path.join(cifar100_dir, 'train'))
    X_test, y_test = load_CIFAR100(os.path.join(cifar100_dir, 'test'))
    # Read the class names from the 'meta' file so that classes[i] is the
    # name of fine label i.
    with open(os.path.join(cifar100_dir, 'meta'), 'rb') as f:
        classes = pickle.load(f, encoding='latin1')['fine_label_names']
    # Hold out the images right after the first num_training ones for validation.
    mask = range(num_training, num_training + num_validation)
    X_val = X_train[mask]
    y_val = y_train[mask]
    mask = range(num_training)
    X_train = X_train[mask]
    y_train = y_train[mask]
    mask = range(num_test)
    X_test = X_test[mask]
    y_test = y_test[mask]
    # Subtract the mean image, computed on the training split only.
    mean_image = np.mean(X_train, axis=0)
    X_train -= mean_image
    X_val -= mean_image
    X_test -= mean_image
    return X_train, y_train, X_val, y_val, X_test, y_test, classes

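CIFAR-100 is a harder dataset than CIFAR-10: it has ten times as many classes and only a tenth as many images per class. The code above likewise visualizes the dataset and prepares the data splits.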
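Besides `fine_labels`, each CIFAR-100 data dictionary also carries a `coarse_labels` list that assigns every image to one of 20 superclasses. The sketch below shows one way to derive a fine-to-coarse lookup from the two parallel lists; the helper name `fine_to_coarse` and the label values in the demo are made up for illustration and are not the real CIFAR-100 IDs.

```python
def fine_to_coarse(fine_labels, coarse_labels):
    """Map each fine label ID to the coarse label ID it co-occurs with."""
    mapping = {}
    for fine, coarse in zip(fine_labels, coarse_labels):
        # setdefault stores the first pairing seen; any later disagreement
        # means the two lists are inconsistent.
        if mapping.setdefault(fine, coarse) != coarse:
            raise ValueError(f"fine label {fine} maps to more than one coarse label")
    return mapping


# Synthetic stand-in for datadict['fine_labels'] and datadict['coarse_labels'].
demo_fine = [0, 1, 2, 0, 1]
demo_coarse = [10, 11, 10, 10, 11]

demo_mapping = fine_to_coarse(demo_fine, demo_coarse)
print(demo_mapping)  # {0: 10, 1: 11, 2: 10}
```

With the real data, the same function could be called on the two lists from the `train` dictionary to group the 100 fine classes by superclass.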

6. Size of the CIFAR-100 Dataset

The CIFAR-100 archive is about 169 MB and can be downloaded from: https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz
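A quick size check is a cheap way to spot an interrupted download. The helper below is a minimal sketch; `size_in_mb` is a hypothetical name, it counts 1 MB as 10**6 bytes, and the demo runs on a temporary file rather than the real archive.

```python
import os
import tempfile


def size_in_mb(path):
    """Return the size of the file at path in megabytes (1 MB = 10**6 bytes)."""
    return os.path.getsize(path) / 1e6


# Demo on a temporary file of exactly 2,500,000 bytes.
with tempfile.NamedTemporaryFile(delete=False) as tmp:
    tmp.write(b"\0" * 2_500_000)
demo_path = tmp.name

demo_size = size_in_mb(demo_path)
os.remove(demo_path)

print(demo_size)  # 2.5
```

Calling `size_in_mb("cifar-100-python.tar.gz")` after the download should report a value close to the figure quoted above; a much smaller number suggests the file is incomplete.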

7. Format of the CIFAR-10 Dataset

Each CIFAR-10 batch is a pickled Python dictionary that contains the following keys:

data: a 10000 × 3072 NumPy array of uint8 values. The first dimension indexes the image and the second holds that image's flattened pixel values, each between 0 and 255.

labels: a list of 10,000 integers in the range 0 to 9. The i-th entry is the class ID of the i-th image in data.
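According to the official dataset description, each 3072-value row stores the red channel first (1024 values), then green, then blue, with each channel in row-major order. The sketch below decodes one row into a 32×32×3 image; the row is synthetic, with constant channel values chosen only so the result is easy to verify.

```python
import numpy as np

# Synthetic row: red channel all 10, green all 20, blue all 30.
row = np.concatenate([
    np.full(1024, 10, dtype=np.uint8),
    np.full(1024, 20, dtype=np.uint8),
    np.full(1024, 30, dtype=np.uint8),
])

# (3072,) -> (3, 32, 32) channel-first -> (32, 32, 3) channel-last,
# which is the layout that plt.imshow expects.
image = row.reshape(3, 32, 32).transpose(1, 2, 0)

print(image.shape)  # (32, 32, 3)
print(image[0, 0])  # [10 20 30]
```

This is the single-image version of the `reshape(...).transpose(0,2,3,1)` call used by the batch loaders.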

8. Accuracy on the CIFAR-100 Dataset

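Classification accuracy on CIFAR-100 typically falls between 70% and 75%, depending on the algorithm and architecture used. CIFAR-100 is more challenging than CIFAR-10, but it is also a very good benchmark for machine learning.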
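Accuracy here means the fraction of test images whose predicted class matches the true class. The following is a minimal sketch of that computation; `top1_accuracy` is a hypothetical helper, and the labels and predictions are made-up values, not results from any real model.

```python
import numpy as np


def top1_accuracy(y_true, y_pred):
    """Return the fraction of predictions that equal the true labels."""
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    if y_true.shape != y_pred.shape:
        raise ValueError("y_true and y_pred must have the same shape")
    return float(np.mean(y_true == y_pred))


# Made-up labels and predictions: 3 of the 5 match.
demo_true = [3, 17, 42, 99, 0]
demo_pred = [3, 17, 41, 99, 5]

demo_acc = top1_accuracy(demo_true, demo_pred)
print(demo_acc)  # 0.6
```

With a trained model, `y_pred` would be the arg-max of the model's class scores on the test set and `y_true` would be `y_test`.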