読者です 読者をやめる 読者になる 読者になる

walkingmask’s development log

IT系の情報などを適当に書いていきます

MENU

TensorFlowで学習済みモデルを使用する(Deep MNIST for Expertsの応用)

Deep MNIST for Experts(TensorFlow Tutorial)を応用したものを記録.

学習データやカーネルを可視化した話はこちら.
walkingmask.hatenablog.com

Deep MNIST for Experts

Deep MNIST for Expertsは,TensorFlowのチュートリアルで,MNISTという手書き文字数字を認識するCNN(畳み込みニューラルネットワーク)をサクッと実装するもの.20000回学習したNNはテストデータの認識率が99.2%にもなるすごいやつ.

学習済みモデルを使ってみる

数時間にも及ぶ学習を終えて,99.2の数字を目にした時「オオおおおお!」と感動できるチュートリアルだが,ちょっと実感が薄い.そこで,学習したこのモデルを使って,自分で作った手書き数字を認識してくれるかテストしてみる.

自作テストデータ

まずは,自作テスト画像の用意.みんな大好きピクセアララー(Pixlr)を使う.

Photo editor online - Pixlr.com

editorを起動して,キャンバスを28×28に設定.そのままでは画像がちっちゃくて見にくいので拡大して,pencilツールでType: Plainで適当な数字を描く.出来上がった画像がこちら(ちっさい).

f:id:walkingmask:20160827030230p:plain

これを学習済みモデルに読み込ませて,「2」と認識してもらうことを目標とする.

コードの記述

必要なコードは

  • モデルの保存/読み込み
  • 新しい画像データの入力

今回参考にさせてもらったWebページがこちら

qiita.com

これをもとに,完成したプログラムがこちら

https://github.com/WalkingMask/tMNIST/blob/master/src/saver/saver.py

この中で重要な部分は,

saver = tf.train.Saver()

ckpt = tf.train.get_checkpoint_state('./')

if ckpt:
  last_model = ckpt.model_checkpoint_path
  print "load " + last_model
  saver.restore(sess, last_model)

  from PIL import Image
  import numpy as np

  new_img = Image.open('./new_data_2.png').convert('L')
  new_img = 1.0 - np.asarray(new_img, dtype="float32") / 255
  new_img = new_img.reshape((1,784))

  prediction = tf.argmax(y_conv,1)
  print("result: %g"%prediction.eval(feed_dict={x: new_img, keep_prob: 1.0}, session=sess))

else:
  学習
  saver.save(sess, "model.ckpt")

saverとckptは参考ページ通り.画像はPILを使って読み取り,グレースケールに変換後numpyのndarrayに代入して元のテストデータと同じ形に計算/reshapeしている.あとは,モデルに新しい画像を食わせてやり,その結果を出力するだけ.

結果は次の通り

% python saver.py 
load ./model.ckpt
result: 2

ちゃんと「2」と認識してくれた.嬉しい.とても嬉しい.CNNめちゃめちゃ可愛い.

ちなみに,学習100回のモデルに画像を計4枚作って読ませてみたが,1枚だけ誤認識した(「9」が7).友人は「4」が1であると認識されて,文字の太さが原因だったのではないかと考察している.今回は入力画像の前処理などを一切していなかったので妥当かも.

とにもかくにも,自分で作った画像を認識してくれて嬉しいので今日はこの辺で.