AI学習用データセットのいろいろな入手方法

 人工知能(AI)による画像分類タスクのチュートリアルでよく利用される、MNISTデータセット(手書き数字)を例にします。データセットはAIが処理しやすいように前処理(サイズを揃えるなど)がされていて、何を表す画像なのか、ラベルやグループ分けされているものもあります。

機械学習ライブラリ(初級編)

前半はPythonを使ってダウンロードする方法、Pythonの復習も兼ねた手順となっています。後半はsklearnを使った入手と、ダウンロード時間を節約するpickleを使った手順を説明しています。

『機械学習ライブラリ(初級編)』に戻る>>

整えられたとてもきれいなデータです

ファイルから準備

ダウンロード

 ファイルは4つです。

train-images-idx3-ubyte.gz学習用イメージデータ60,000件
train-labels-idx1-ubyte.gz学習用正解ラベル60,000件
t10k-images-idx3-ubyte.gzテスト用イメージデータ10,000件
t10k-labels-idx1-ubyte.gzテスト用正解ラベル10,000件

◆カレントディレクトリにダウンロードします。
 (ダウンロード済みの場合は不要)

import urllib.request
import glob
flist = ['train-images-idx3-ubyte.gz', 'train-labels-idx1-ubyte.gz', 't10k-images-idx3-ubyte.gz', 't10k-labels-idx1-ubyte.gz']
for fname in flist:
    urllib.request.urlretrieve('http://yann.lecun.com/exdb/mnist/'+fname, './'+fname)
glob.glob('./*.gz')
['train-images-idx3-ubyte.gz',
 'train-labels-idx1-ubyte.gz',
 't10k-images-idx3-ubyte.gz',
 't10k-labels-idx1-ubyte.gz']
--
4ファイルあります。

データを取得

◆ダウンロードした「gzip」ファイルを読み込みます。

import numpy as np
import gzip

#gzipファイル読み込み
flist = ['train-images-idx3-ubyte.gz', 'train-labels-idx1-ubyte.gz', 't10k-images-idx3-ubyte.gz', 't10k-labels-idx1-ubyte.gz']
data = []
for i in range(4):
    offno = 8 if i%2 else 16
    with gzip.open('./'+flist[i], 'rb') as f:
        data.append(np.frombuffer(f.read(), np.uint8, offset=offno))

x_train = data[0].reshape(-1, 784) #学習用イメージデータ
t_train = data[1]                  #学習用正解ラベル
x_test = data[2].reshape(-1, 784)  #テスト用イメージデータ
t_test = data[3]                   #テスト用正解ラベル

#サイズ確認
print('x_train:'+str(x_train.shape))
print('t_train:'+str(t_train.shape))
print('x_test:'+str(x_test.shape))
print('t_test:'+str(t_test.shape))
x_train:(60000, 784)
t_train:(60000,)
x_test:(10000, 784)
t_test:(10000,)
--
学習用データは60,000件、テスト用データは10,000件。

Google Colaboratory」の場合は、プリインストールデータが利用できます。 『Google ColaboratoryのLinuxコマンドに慣れる』 で紹介しています。

表示する

◆「matplotlib」を使って表示します。

import matplotlib.pyplot as plt 
fig=plt.figure()
ax = []
for i in range(100):
    im = x_train[i].reshape(28,28)
    ax.append(fig.add_subplot(10,10,i+1))
    ax[i].axes.xaxis.set_visible(False)
    ax[i].axes.yaxis.set_visible(False)
    ax[i].imshow(im)
plt.show()
MNISTデータの複数画像表示

一般的なカラー画像ではありません。白黒と同じ2色カラーです。

機械学習ライブラリを使う

sklearnでデータ取得

 MNISTのようなチュートリアルに使う有名なデータセットは、ファイルから準備するよりも簡単に利用する方法があります。sklearnを例に示しますが、フレームワークPytorchやTensorFlowにも同様のものが用意されています。

from sklearn import datasets

mnist = datasets.fetch_openml('mnist_784') #ダウンロードに少し時間がかかる
print('data:'+str(mnist.data.shape))
print('label:'+str(mnist.target.shape))
data:(70000, 784)
label:(70000,)
--
画像データが70,000件、ラベルデータが70,000件です。

まとめてあるので70,000件(60,000+10,000)になっていますが、中身は同じです。

メモリデータの保存

「sklearn」のfetch_openml(上記コード3行目)はメモリにデータを格納します。pythonとのセッションを終了すると、メモリデータはクリアされます。再度実行すればいいのですが、ダウンロードには多少の時間がかかります。pickleを使えば、メモリデータを保存できます。

◆上記「mnist(sklearnのBunch)」のメモリデータを「mnist_data」の名前で保存します。

import pickle, glob

with open('./mnist_data', 'wb') as f:
    pickle.dump(mnist, f)
glob.glob('mnist*')
['mnist_data']
--
カレントディレクトリに「mnist_data」が作成されます。
◆保存したpickleデータを読み込みます。
import pickle

with open('./mnist_data','rb') as f:
    mnist2 = pickle.load(f)
print('data(mnist2):'+str(mnist2.data.shape))
print('label(mnist2):'+str(mnist2.target.shape))
data(mnist2):(70000, 784)
label(mnist2):(70000,)
--
同じく、画像データが70,000件、ラベルデータが70,000件です。

『機械学習ライブラリ(初級編)』に戻る>>

以上