MNIST 数据集是一个手写数字识别训练数据集,来自美国国家标准与技术研究所National Institute of Standards and Technology (NIST)。训练集 (training set) 由来自 250 个不同人手写的数字构成,其中 50% 是高中学生,50% 来自人口普查局 (the Census Bureau) 的工作人员。测试集(test set) 也是同样比例的手写数字数据。
MNIST 数据集可在 http://yann.lecun.com/exdb/mnist/ 获取, 它包含了四个部分:
- Training set images: train-images-idx3-ubyte.gz (9.9 MB, 解压后 47 MB, 包含 60,000 个样本)
- Training set labels: train-labels-idx1-ubyte.gz (29 KB, 解压后 60 KB, 包含 60,000 个标签)
- Test set images: t10k-images-idx3-ubyte.gz (1.6 MB, 解压后 7.8 MB, 包含 10,000 个样本)
- Test set labels: t10k-labels-idx1-ubyte.gz (5KB, 解压后 10 KB, 包含 10,000 个标签)
因数据集是按特殊格式压缩存储的,所以如果在训练或应用模型时想看一下原图就会比较困难,本文讲解如果反向解析这个数据集。
mnist image图片集结构
可以看出在train-images.idx3-ubyte中,第一个数为32位的整数(魔数,图片类型的数),第二个数为32位的整数(图片的个数),第三和第四个也是32为的整数(分别代表图片的行数和列数),接下来的都是一个字节的无符号数(即像素,值域为0~255),因此,我们只需要依次获取魔数和图片的个数,然后获取图片的长和宽,最后逐个像素读取就可以了。
如何使用python解析数据呢? 首先需要安装python的图形处理库PIL,这个库支持像素级别的图像处理,对于学习数字图像处理有很大的帮助。安装完成之后,就可以进行图像的解析了。
sudo pip install Pillow #PIL python image library
首先打开文件,然后分别读取魔数,图片个数,以及行数和列数,在struct中,可以看到,使用了’>IIII’,这是什么意思呢?意思就是使用大端规则,读取四个整形数(Integer),如果要读取一个字节,则可以用’>B’(当然,这里用没用大端规则都是一样的,因此只有两个或两个以上的字节才有用)。
什么是大端规则呢?不懂的可以百度一下,这个不再赘述(http://baike.baidu.com/link?url=Bgg8b0vRr3b_SeGyOl8U4DmAbIQT9swGuNtD_21ctEI_NliqsQ-mKF73YT90EILF2EQy50mEua_M4z6Cma3rmK)
然后对于每张图片,先创建一张空白的图片,其中的’L’代表这张图片是灰度图,最后逐个像素读取,然后写进空白图片里,最后保存图片,就可以了
mnist label标签集结构
可以发现,与上面的非常相似,只不过这里每一个字节变成了标签而已(标签大小为0~9)
读取图片和标签数据
根据上述结构的分析,我们可以通过python将mnist的图片和对应label解析出来:
程序源代码如下:
vim read_mnist.py
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# 读取mnist数据集的图片和label,并生成以序号+label命名的图片
# cmd: nohup python read_mnist.py >log/read_mnist.log &
import os
import random
import platform
import numpy
import subprocess
from PIL import Image
import utils
#mnist训练集目录
#mnist_path = '/home/work/paddle/sample/recognize_digits/train/data/mnist/'
mnist_path = '/home/work/.cache/paddle/dataset/mnist/'
train_image = mnist_path + 'train-images-idx3-ubyte.gz'
train_label = mnist_path + 'train-labels-idx1-ubyte.gz'
test_image = mnist_path + 't10k-images-idx3-ubyte.gz'
test_label = mnist_path + 't10k-labels-idx1-ubyte.gz'
#读取出的图片存放位置
output_path = '/home/work/paddle/sample/recognize_digits/train/data/'
def reader_mnist(image_filename,label_filename,buffer_size=200,path='train'):
if platform.system()=='Linux':
zcat_cmd = 'zcat'
else:
raise NotImplementedError("This program is suported on Linux,\
but your platform is" + platform.system())
# 读取mnist图片集
sub_img = subprocess.Popen([zcat_cmd, image_filename], stdout = subprocess.PIPE)
sub_img.stdout.read(16) # 跳过前16个magic字节
# 读取mnist标签集
sub_lab = subprocess.Popen([zcat_cmd, label_filename], stdout = subprocess.PIPE)
sub_lab.stdout.read(8) # 跳过前8个magic字节
try:
id = 0 #图片集序号
while True: #前面使用try,故若再读取过程中遇到结束则会退出
# 批量读取label,每个label占1个字节
labels = numpy.fromfile(sub_lab.stdout,'ubyte',count=buffer_size).astype("int")
if labels.size != buffer_size:
break
# 批量读取image,每个image占28*28个字节,并转换为28*28的二维float数组
images = numpy.fromfile(sub_img.stdout,'ubyte',count=buffer_size * 28 * 28).reshape(buffer_size, 28, 28).astype("float32")
for i in xrange(buffer_size):
id += 1
img = images[i]
num = labels[i]
#print img
#print num
#创建新28*28图片对象
image = Image.new('L', (28, 28))
for x in xrange(28):
for y in xrange(28):
#print img[x][y]
image.putpixel((y, x), int(img[x][y])) #按像素写入
#保存图片(序号-label.png)
utils.mkdir(output_path + path)
save_file = output_path + path + '/' + str(id) + '-' + str(num) + '.png'
image.save(save_file)
print save_file
#break
#break
finally:
#结束读取进程
sub_img.terminate()
sub_lab.terminate()
if __name__ == '__main__':
#读取训练集
reader_mnist(train_image,train_label,buffer_size=200,path='train')
#读取测试集
reader_mnist(test_image,test_label,buffer_size=200,path='test')
解析出的图片:(文件名:序号-label.png)
yan 2018.12.3 22:19
参考:
http://yann.lecun.com/exdb/mnist/
https://www.liaoxuefeng.com/wiki/001374738125095c955c1e6d8bb493182103fac9270762a000/00140767171357714f87a053a824ffd811d98a83b58ec13000