tensorflow といえば Python と思っておられる方も多いのではないでしょうか。間違いではないのですが、これは初期に作られた Python 向けのバインディングに研究者達が多く食いついた結果に過ぎないと思っています。実際 tensorflow は現在 C言語、C++、Python、Java、Go から利用する事ができ、最近では JavaScript にも移植されています。筆者自身も Go で tensorflow を使ったシステムを構築し、運用保守しています。問題も発生せず機嫌良く動いています。学習の利用部分は GPU のパフォーマンスに依存しますが、それ以外の部分については各言語の実装に依存します。上手く作れば Python よりも速い実装にする事も可能です。とても便利な世の中になってきたのですが、学習環境に関しては Python 界で培われてきた資産が沢山あるので、それを使うのがやはり効率が良いです。しかしながら最近では、既に色々な研究者がモデルファイルを再利用可能な形で配布しており、独自のモデルを扱わないのであれば C/C++、Java や Go 等を使ってリーズナブルに認識処理を実装する事ができる様になりました。
Go で tensorflow をやろうとするユーザがもっと増えればいいなと思います。そこで今回は、Go でどうやって tensorflow を扱うかを、画像の認識プログラムを実装しながら説明したいと思います。
Go で tensorflow を使う為には libtensorflow.so が必要です。各 OS 向けにコンパイル済みのバイナリがあります。libtensorflow のインストールは以下のページが参考になります。
Installing TensorFlow for C | TensorFlow
TensorFlow provides a C API defined in c_api.h , which is suitable for building bindings for other l...
https://www.tensorflow.org/install/install_c
libtensorflow.so とヘッダファイルをシステムにインストールすればあとは go get で Go のパッケージがインストール出来ます。
go get github.com/tensorflow/tensorflow/tensorflow/go
最近の Go ではビルド結果がキャッシュされるので、一度 go get しておけば tensorflow の cgo パッケージが毎回ビルドされる事はありません。
さて、画像の認識プログラムを実装するには、まず物体を検出する為のモデルを持ってくる必要があります。ありがたい事に tensorflow のリポジトリで物体検出の学習済みモデルを提供してくれています。
models/research/object_detection at master - tensorflow/models - GitHub
Tensorflow Object Detection API Creating accurate machine learning models capable of localizing and ...
https://github.com/tensorflow/models/tree/master/research/object_detection
今回はこのリポジトリに含まれる ssd_mobilenet_v1_coco_11_06_2017.tar.gz を使います。このモデルは COCO (Common Objects in Context) で選定されたオブジェクトを Single Shot (Object) Detection する目的で作られた物で、以下の Jupyter Notebook からその手順を参照する事が出来ます。
models/object_detection_tutorial.ipynb at master - tensorflow/models - GitHub
YknZhu-patch-1 a3c_blogpost achowdhery-patch-1 achowdhery-patch-2 aselle-patch-1 asimshankar-patch-1...
https://github.com/tensorflow/models/blob/master/research/object_detection/object_detection_tutorial.ipynb
もしご自分で学習してみたい方は以下のページが参考になります。
Quick Start: Distributed Training on the Oxford-IIIT Pets Dataset on Google Cloud
Quick Start: Distributed Training on the Oxford-IIIT Pets Dataset on Google Cloud This page is a wal...
https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/running_pets.md
さらにご自分でアノテーションもやってみたい方はこのあたりを参考に XML ファイルを作る必要があります。なかなか骨が折れる作業ですが、うまくいくとこんな事が出来る様になります。
python でウォーリーのモデルを学習させて #golang の tensorflow で検出、image/draw でマスクしてみた。やったね✌ pic.twitter.com/vWn0J4CC5G
— mattn (@mattn_jp) August 15, 2018
この学習済みモデルは、入力がデコード済みの画像バイナリ image_tensor、出力が認識した矩形 detection_boxes、認識したクラス名 detection_classes、認識率 detection_scores、認識個数 num_detections になります。クラス名は番号で帰ってくるので、このモデルを生成した際に使われたクラス名のリストが書かれたファイル coco_labels.txt も取得しておきます。
なお、Go で Python の学習済みモデルを利用するには ckpt 形式でなく Protocol Buffers の形式でなければなりません。
model, err := ioutil.ReadFile(filepath.Join(dir, "frozen_inference_graph.pb"))
if err != nil {
log.Fatal(err)
}
labels, err := loadLabels(filepath.Join(dir, "coco_labels.txt"))
if err != nil {
log.Fatal(err)
}
graph := tf.NewGraph()
if err := graph.Import(model, ""); err != nil {
log.Fatal(err)
}
session, err := tf.NewSession(graph, nil)
if err != nil {
log.Fatal(err)
}
defer session.Close()
次に画像ファイルからテンソルを作る必要があります。tensorflow/go/op パッケージに各種画像フォーマットに対するデコーダがありますので、以下の様に画像バイナリからノーマライズしつつテンソルを得る関数を作ります。
func decodeBitmapGraph() (*tf.Graph, tf.Output, tf.Output, error) {
s := op.NewScope()
input := op.Placeholder(s, tf.String)
output := op.ExpandDims(
s,
op.DecodeBmp(s, input, op.DecodeBmpChannels(3)),
op.Const(s.SubScope("make_batch"), int32(0)))
graph, err := s.Finalize()
return graph, input, output, err
}
今回の例ではやっていませんが、例えば画像をリサイズしつつ、偏差を用いてノーマライズしてテンソルを得るグラフであれば以下の様に実装します。
// Div and Sub perform (value-Mean)/Scale for each pixel
output := op.Div(s,
op.Sub(s,
// Resize to 224x224 with bilinear interpolation
op.ResizeBilinear(s,
// Create a batch containing a single image
op.ExpandDims(s,
// Use decoded pixel values
op.Cast(s, decode, tf.Float),
op.Const(s.SubScope("make_batch"), int32(0))),
op.Const(s.SubScope("size"), []int32{H, W})),
op.Const(s.SubScope("mean"), Mean)),
op.Const(s.SubScope("scale"), Scale))
これは変換処理自身の実装ではなく、GPU に画像の変換機を送り込む為の準備です。この変換機を使って、画像のバイナリからテンソルを作ります。ちなみに以下の関数では Go の image.Image も作っていますが、これは後で画像にマーカーを描き込む為のベース画像として利用する為です。
func makeTensorFromImage(img []byte) (*tf.Tensor, image.Image, error) {
tensor, err := tf.NewTensor(string(img))
if err != nil {
return nil, nil, err
}
normalizeGraph, input, output, err := decodeBitmapGraph()
if err != nil {
return nil, nil, err
}
normalizeSession, err := tf.NewSession(normalizeGraph, nil)
if err != nil {
return nil, nil, err
}
defer normalizeSession.Close()
normalized, err := normalizeSession.Run(
map[tf.Output]*tf.Tensor{input: tensor},
[]tf.Output{output},
nil)
if err != nil {
return nil, nil, err
}
r := bytes.NewReader(img)
i, _, err := image.Decode(r)
if err != nil {
return nil, nil, err
}
return normalized[0], i, nil
}
テンソルが得られたら後は上記の通り、image_tensor を入力、detection_boxes, detection_scores, detection_classes, num_detections を出力としたセッションを実行します。
func detectObjects(session *tf.Session, graph *tf.Graph, input *tf.Tensor) ([]float32, []float32, [][]float32, error) {
inputop := graph.Operation("image_tensor")
output, err := session.Run(
map[tf.Output]*tf.Tensor{
inputop.Output(0): input,
},
[]tf.Output{
graph.Operation("detection_boxes").Output(0),
graph.Operation("detection_scores").Output(0),
graph.Operation("detection_classes").Output(0),
graph.Operation("num_detections").Output(0),
},
nil)
if err != nil {
return nil, nil, nil, fmt.Errorf("Error running session: %v", err)
}
probabilities := output[1].Value().([][]float32)[0]
classes := output[2].Value().([][]float32)[0]
boxes := output[0].Value().([][][]float32)[0]
return probabilities, classes, boxes, nil
}
後は得られた結果から元画像の座標に変換し、四角やテキストを書き込みます。
probabilities, classes, boxes, err := detectObjects(session, graph, tensor)
if err != nil {
log.Fatalf("error making prediction: %v", err)
}
bounds := img.Bounds()
canvas := image.NewRGBA(bounds)
draw.Draw(canvas, bounds, img, image.Pt(0, 0), draw.Src)
i := 0
for float64(probabilities[i]) > probability {
idx := int(classes[i])
y1 := int(float64(bounds.Min.Y) + float64(bounds.Dy())*float64(boxes[i][0]))
x1 := int(float64(bounds.Min.X) + float64(bounds.Dx())*float64(boxes[i][1]))
y2 := int(float64(bounds.Min.Y) + float64(bounds.Dy())*float64(boxes[i][2]))
x2 := int(float64(bounds.Min.X) + float64(bounds.Dx())*float64(boxes[i][3]))
drawRect(canvas, image.Rect(x1, y1, x2, y2), color.RGBA{255, 0, 0, 0})
drawString(
canvas,
image.Pt(x1, y1),
colornames.Map[colornames.Names[idx]],
fmt.Sprintf("%s (%2.0f%%)", labels[idx], probabilities[idx]*100.0))
i++
}
実行は以下の様に行います。
./go-object-detect-from-image input.jpg
実行すると output.jpg というファイルが出力され、以下の様にマーカーが表示されます。
コードは GitHub に置いてあります。
GitHub - mattn/go-object-detect-from-image
go-object-detect-from-image detect objects from image file Usage ./go-object-detect-from-image input...
https://github.com/mattn/go-object-detect-from-image
モデルファイルとラベルファイルも同梱してあるので、ビルドすればそのまま実行出来る様にしてあります。要点だけ掴めばそれほど難しくないですし、他のモデルファイルを使って色々な物体検出を試す事も出来ます。Go で tensorflow を使っているユーザはこれからもっと増えると思います。ぜひ面白いプログラムを作ってみて下さい。