如何用python解析mnist图片

MNIST 数据集是一个手写数字识别训练数据集,来自美国国家标准与技术研究所National Institute of Standards and Technology (NIST)。训练集 (training set) 由来自 250 个不同人手写的数字构成,其中 50% 是高中学生,50% 来自人口普查局 (the Census Bureau) 的工作人员。测试集(test set) 也是同样比例的手写数字数据。

MNIST 数据集可在 http://yann.lecun.com/exdb/mnist/ 获取, 它包含了四个部分:

  • Training set images: train-images-idx3-ubyte.gz (9.9 MB, 解压后 47 MB, 包含 60,000 个样本)
  • Training set labels: train-labels-idx1-ubyte.gz (29 KB, 解压后 60 KB, 包含 60,000 个标签)
  • Test set images: t10k-images-idx3-ubyte.gz (1.6 MB, 解压后 7.8 MB, 包含 10,000 个样本)
  • Test set labels: t10k-labels-idx1-ubyte.gz (5KB, 解压后 10 KB, 包含 10,000 个标签)

因数据集是按特殊格式压缩存储的,所以如果在训练或应用模型时想看一下原图就会比较困难,本文讲解如果反向解析这个数据集。

mnist image图片集结构

可以看出在train-images.idx3-ubyte中,第一个数为32位的整数(魔数,图片类型的数),第二个数为32位的整数(图片的个数),第三和第四个也是32为的整数(分别代表图片的行数和列数),接下来的都是一个字节的无符号数(即像素,值域为0~255),因此,我们只需要依次获取魔数和图片的个数,然后获取图片的长和宽,最后逐个像素读取就可以了。

如何使用python解析数据呢? 首先需要安装python的图形处理库PIL,这个库支持像素级别的图像处理,对于学习数字图像处理有很大的帮助。安装完成之后,就可以进行图像的解析了。

sudo pip install Pillow   #PIL python image library

首先打开文件,然后分别读取魔数,图片个数,以及行数和列数,在struct中,可以看到,使用了’>IIII’,这是什么意思呢?意思就是使用大端规则,读取四个整形数(Integer),如果要读取一个字节,则可以用’>B’(当然,这里用没用大端规则都是一样的,因此只有两个或两个以上的字节才有用)。

什么是大端规则呢?不懂的可以百度一下,这个不再赘述(http://baike.baidu.com/link?url=Bgg8b0vRr3b_SeGyOl8U4DmAbIQT9swGuNtD_21ctEI_NliqsQ-mKF73YT90EILF2EQy50mEua_M4z6Cma3rmK)

然后对于每张图片,先创建一张空白的图片,其中的’L’代表这张图片是灰度图,最后逐个像素读取,然后写进空白图片里,最后保存图片,就可以了

mnist label标签集结构

可以发现,与上面的非常相似,只不过这里每一个字节变成了标签而已(标签大小为0~9)

读取图片和标签数据

根据上述结构的分析,我们可以通过python将mnist的图片和对应label解析出来:

程序源代码如下:

vim read_mnist.py
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# 读取mnist数据集的图片和label,并生成以序号+label命名的图片
# cmd: nohup python read_mnist.py >log/read_mnist.log &
import os
import random
import platform
import numpy
import subprocess
from PIL import Image
import utils

#mnist训练集目录
#mnist_path = '/home/work/paddle/sample/recognize_digits/train/data/mnist/'
mnist_path = '/home/work/.cache/paddle/dataset/mnist/'
train_image   = mnist_path + 'train-images-idx3-ubyte.gz'
train_label   = mnist_path + 'train-labels-idx1-ubyte.gz'
test_image    = mnist_path + 't10k-images-idx3-ubyte.gz'
test_label    = mnist_path + 't10k-labels-idx1-ubyte.gz'

#读取出的图片存放位置
output_path = '/home/work/paddle/sample/recognize_digits/train/data/'


def reader_mnist(image_filename,label_filename,buffer_size=200,path='train'):
    if platform.system()=='Linux':
        zcat_cmd = 'zcat'
    else:
        raise NotImplementedError("This program is suported on Linux,\
                                  but your platform is" + platform.system())

    # 读取mnist图片集
    sub_img = subprocess.Popen([zcat_cmd, image_filename], stdout = subprocess.PIPE)
    sub_img.stdout.read(16) # 跳过前16个magic字节

    # 读取mnist标签集
    sub_lab = subprocess.Popen([zcat_cmd, label_filename], stdout = subprocess.PIPE)
    sub_lab.stdout.read(8)  # 跳过前8个magic字节

    try:
        id = 0 #图片集序号
        while True:         #前面使用try,故若再读取过程中遇到结束则会退出
            # 批量读取label,每个label占1个字节
            labels = numpy.fromfile(sub_lab.stdout,'ubyte',count=buffer_size).astype("int")
            if labels.size != buffer_size:
                break
            # 批量读取image,每个image占28*28个字节,并转换为28*28的二维float数组
            images = numpy.fromfile(sub_img.stdout,'ubyte',count=buffer_size * 28 * 28).reshape(buffer_size, 28, 28).astype("float32")
            for i in xrange(buffer_size):
                id += 1
                img = images[i]
		num = labels[i]
                #print img
                #print num
                #创建新28*28图片对象
                image = Image.new('L', (28, 28))
                for x in xrange(28):
                    for y in xrange(28):
		        #print img[x][y]
                        image.putpixel((y, x), int(img[x][y])) #按像素写入

                #保存图片(序号-label.png)
		utils.mkdir(output_path + path)
		save_file = output_path + path + '/' + str(id) + '-' + str(num) + '.png'
		image.save(save_file) 
		print save_file 
                #break
	    #break

    finally: 
	#结束读取进程
        sub_img.terminate()
        sub_lab.terminate()


if __name__ == '__main__':
    #读取训练集
    reader_mnist(train_image,train_label,buffer_size=200,path='train')
    #读取测试集
    reader_mnist(test_image,test_label,buffer_size=200,path='test')

解析出的图片:(文件名:序号-label.png

yan 2018.12.3 22:19

参考:

http://yann.lecun.com/exdb/mnist/

https://www.liaoxuefeng.com/wiki/001374738125095c955c1e6d8bb493182103fac9270762a000/00140767171357714f87a053a824ffd811d98a83b58ec13000

发表评论

电子邮件地址不会被公开。