如何用python解析mnist图片

MNIST 数据集是一个手写数字识别训练数据集,来自美国国家标准与技术研究所National Institute of Standards and Technology (NIST)。训练集 (training set) 由来自 250 个不同人手写的数字构成,其中 50% 是高中学生,50% 来自人口普查局 (the Census Bureau) 的工作人员。测试集(test set) 也是同样比例的手写数字数据。

MNIST 数据集可在 http://yann.lecun.com/exdb/mnist/ 获取, 它包含了四个部分:

  • Training set images: train-images-idx3-ubyte.gz (9.9 MB, 解压后 47 MB, 包含 60,000 个样本)
  • Training set labels: train-labels-idx1-ubyte.gz (29 KB, 解压后 60 KB, 包含 60,000 个标签)
  • Test set images: t10k-images-idx3-ubyte.gz (1.6 MB, 解压后 7.8 MB, 包含 10,000 个样本)
  • Test set labels: t10k-labels-idx1-ubyte.gz (5KB, 解压后 10 KB, 包含 10,000 个标签)

因数据集是按特殊格式压缩存储的,所以如果在训练或应用模型时想看一下原图就会比较困难,本文讲解如果反向解析这个数据集。

一、mnist数据结构

可以看出在train-images.idx3-ubyte中,第一个数为32位的整数(魔数,图片类型的数),第二个数为32位的整数(图片的个数),第三和第四个也是32为的整数(分别代表图片的行数和列数),接下来的都是一个字节的无符号数(即像素,值域为0~255),因此,我们只需要依次获取魔数和图片的个数,然后获取图片的长和宽,最后逐个像素读取就可以了。

如何使用python解析数据呢? 首先需要安装python的图形处理库PIL,这个库支持像素级别的图像处理,对于学习数字图像处理有很大的帮助。安装完成之后,就可以进行图像的解析了。

sudo pip install Pillow   #PIL python image library

首先打开文件,然后分别读取魔数,图片个数,以及行数和列数,在struct中,可以看到,使用了’>IIII’,这是什么意思呢?意思就是使用大端规则,读取四个整形数(Integer),如果要读取一个字节,则可以用’>B’(当然,这里用没用大端规则都是一样的,因此只有两个或两个以上的字节才有用)。

什么是大端规则呢?不懂的可以百度一下,这个不再赘述(http://baike.baidu.com/link?url=Bgg8b0vRr3b_SeGyOl8U4DmAbIQT9swGuNtD_21ctEI_NliqsQ-mKF73YT90EILF2EQy50mEua_M4z6Cma3rmK)

然后对于每张图片,先创建一张空白的图片,其中的’L’代表这张图片是灰度图,最后逐个像素读取,然后写进空白图片里,最后保存图片,就可以了

二、mnist标签数据结构

可以发现,与上面的非常相似,只不过这里每一个字节变成了标签而已(标签大小为0~9)

三、解析mnist图片和标签数据

好了,通过上述讲解,最后我们可以通过python将mnist解析出来了,看一下效果:

解析出的图片数据:

解析出的label:

程序源代码如下:

#!/usr/bin/env python
# -*- coding: utf-8 -*-
from PIL import Image
import struct
#import matplotlib.pyplot as plt
#import numpy as np

# 读取mnist图片
def read_image(filename):
  f = open(filename, 'rb')
  index = 0
  buf = f.read()
  f.close()

  numRows = 28
  numColumns = 28
  # '>IIII'是指使用大端法读取4个unsinged int32
  magic, numImages, numRows, numColumns = struct.unpack_from('>IIII' , buf , index)
  index += struct.calcsize('>IIII')

  # 按像素还原图片
  for i in xrange(numImages):
    image = Image.new('L', (numRows, numColumns))
    for x in xrange(numRows):
      for y in xrange(numColumns):
        image.putpixel((y, x), int(struct.unpack_from('>B', buf, index)[0]))
        index += struct.calcsize('>B')

    print 'save ' + str(i) + 'image'
    image.save('image/' + str(i) + '.png')

# 读取mnist label
def read_label(filename, saveFilename):
  f = open(filename, 'rb')
  index = 0
  buf = f.read()
  f.close()

  magic, labels = struct.unpack_from('>II' , buf , index)
  index += struct.calcsize('>II')
  
  #labelArr = [0] * labels
  labelArr = [0] * 2000
  #for x in xrange(labels):
  for x in xrange(2000):
    labelArr[x] = int(struct.unpack_from('>B', buf, index)[0])
    index += struct.calcsize('>B')
  save = open(saveFilename, 'w')
  save.write(','.join(map(lambda x: str(x), labelArr)))
  save.write('\n')
  save.close()
  print 'save labels success'


if __name__ == '__main__':
  #具体dataset数据和label文件请从官方或通过paddle示例获取
  #mnist_path = '/home/work/.cache/paddle/dataset/mnist/'
  mnist_path = 'mnist/'
  # 还原mnist图片
  try:
    read_image(mnist_path + 't10k-images-idx3-ubyte.gz')
  except:
    pass
  # 还原mnist标签
  read_label(mnist_path + 't10k-labels-idx1-ubyte.gz', 'label.txt')

yan 2018.4.16 22:19

参考:

https://blog.csdn.net/simple_the_best/article/details/75267863

http://yann.lecun.com/exdb/mnist/

 

发表评论

电子邮件地址不会被公开。