walkingmask’s development log

IT系の情報などを適当に書いていきます

MENU

Saver.save で保存したモデル (checkpoint, data, index, meta) をプロトコルバッファ (pb) に変換する

かなり泥臭い方法ですが、自分なりに解決した記録です。

前提として、モデルの定義部分の Python コードが必要です。

要約すると

  1. モデル定義のコードと meta などからグラフ作成
  2. tf.get_default_graph().as_graph_def().node を print してそれっぽいのを探す
  3. tf.graph_util.convert_variables_to_constants で pb ファイル生成

はじめに

Saver.save でモデルを保存すると、以下のようなファイルが保存ディレクトリ先に生成されると思います。

  • checkpoint
  • foo.data-XXX-of-XXX
  • foo.index
  • foo.meta

過去に自分で学習を行ったとか、公開されている学習済みモデルなど、この形式を使いたい場合があります。

しかし、これは TensorFlow.js などで使用することができません (2018/06/30 現在)。

そこで、これを プロトコルバッファ形式である foo.pb に変換しました。

今回、対象にしたのはこちらの microexpnet のモデルです。

手順

pb ファイルを生成するには、tf.graph_util.convert_variables_to_constants を使うと良さそうです。

stackoverflow.com

しかし、これを使うには出力層のグラフ内でのノード名が必要となります。bazel を使って graph_transforms/summarize_graph というツールをビルドして調べる方法もあるようなのですが、brew でインストールしたりが嫌だったので、泥臭く調べることにしました(本当はこちらが良いと思います)。

blogs.yahoo.co.jp

まず、モデル定義のコードと meta などからグラフを作成します。(重みの読み込みに関するコードがあれば、それも使います)

import tensorflow as tf
from MicroExpNet import MicroExpNet

modelDir = './Models/OuluCASIA'

x = tf.placeholder(tf.float32, shape=[None, 84*84])
classifier = MicroExpNet(x)
weights_biases_deployer = tf.train.Saver({"wc1": classifier.w["wc1"], \
                                          "wc2": classifier.w["wc2"], \
                                          "wfc": classifier.w["wfc"], \
                                          "wo":  classifier.w["out"], \
                                          "bc1": classifier.b["bc1"], \
                                          "bc2": classifier.b["bc2"], \
                                          "bfc": classifier.b["bfc"], \
                                          "bo":  classifier.b["out"]})

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    weights_biases_deployer.restore(sess, tf.train.latest_checkpoint(modelDir))

そして、print('\n'.join([n.name for n in tf.get_default_graph().as_graph_def().node]) で、それらしいノードを探します。

以上です。

ポイントとしては、"out" や "bo", "wo" といったキーワードや、モデル定義の最終層を手がかりに探すことです。

microexpnetの場合は、"Add_1" が最終出力ノードでした。

これを pb にするには、sess の中で

    frozen_graph_def = tf.graph_util.convert_variables_to_constants(
        sess,
        sess.graph_def,
        output_node_names)

    with open('output_graph.pb', 'wb') as f:
      f.write(frozen_graph_def.SerializeToString())

だけで pb ファイルができます。(output_node_names = ['Add_1'])

また、余談ですが (tensorflowjs で frozen_model オプションの時に必要) input_node_name はだいたい "Placeholder" で定義されています(当たり前っちゃ当たり前ですが)。

おわりに

泥臭い方法で、Saver.save で保存したモデルをプロトコルバッファ形式に変換しました。

素直に bazel インストールして使えよって話はありますが、そういうのが嫌な時もありますよね!

次回からは素直に生きれたらいいなと思います😊

参考文献