2019/03/07

Recent entries from same category

  1. Go 言語プログラミングエッセンスという本を書きました。
  2. errors.Join が入った。
  3. unsafe.StringData、unsafe.String、unsafe.SliceData が入った。
  4. Re: Go言語で画像ファイルか確認してみる
  5. net/url に JoinPath が入った。

Google launches TensorFlow Lite 1.0 for mobile and embedded devices | VentureBeat

Google today introduced TensorFlow Lite 1.0 , its framework for developers deploying AI models on mo...

https://venturebeat.com/2019/03/06/google-launches-tensorflow-lite-1-0-for-mobile-and-embeddable-devices/

TensorFlow Lite 1.0 がリリースされた、このめでたい日に TensorFlow にプレゼントを送ろうと思って作りました。

TensorFlow Lite を Go 言語から扱えるパッケージです。なるべく C の API に忠実に実装したので Go 言語感がない部分もあるかもしれませんが、それはこれからです。

GitHub - mattn/go-tflite

go-tflite Go binding for TensorFlow Lite Usage See _example/main.go Requirements TensorFlow Lite Ins...

https://github.com/mattn/go-tflite

使い方は TensorFlow Lite でプログラミングした事がある方なら分かるはずです。

package main

import (
    "bufio"
    "flag"
    "fmt"
    "image"
    _ "image/png"
    "log"
    "os"
    "sort"

    "github.com/mattn/go-tflite"
    "github.com/nfnt/resize"
)

func loadLabels(filename string) ([]stringerror) {
    labels := []string{}
    f, err := os.Open("labels.txt")
    if err != nil {
        return nil, err
    }
    defer f.Close()
    scanner := bufio.NewScanner(f)
    for scanner.Scan() {
        labels = append(labels, scanner.Text())
    }
    return labels, nil
}

func main() {
    var model_path, label_path, image_path string
    flag.StringVar(&model_path, "model""mobilenet_quant_v1_224.tflite""path to model file")
    flag.StringVar(&label_path, "label""labels.txt""path to label file")
    flag.StringVar(&image_path, "image""grace_hopper.png""path to image file")
    flag.Parse()

    f, err := os.Open(image_path)
    if err != nil {
        log.Fatal(err)
    }
    defer f.Close()

    img, _, err := image.Decode(f)
    if err != nil {
        log.Fatal(err)
    }

    labels, err := loadLabels(label_path)
    if err != nil {
        log.Fatal(err)
    }

    model := tflite.NewModelFromFile(model_path)
    if model == nil {
        log.Fatal("cannot load model")
    }
    defer model.Delete()

    options := tflite.NewInterpreterOptions()
    options.SetNumThread(4)

    interpreter := tflite.NewInterpreter(model, options)
    if interpreter == nil {
        log.Fatal("cannot create interpreter")
    }
    defer interpreter.Delete()

    status := interpreter.AllocateTensors()
    if status != tflite.OK {
        log.Fatal("allocate failed")
    }

    input := interpreter.GetInputTensor(0)
    wanted_height := input.Dim(1)
    wanted_width := input.Dim(2)
    wanted_channels := input.Dim(3)
    wanted_type := input.Type()
    fmt.Println(wanted_height, wanted_width, wanted_channels, wanted_type)

    resized := resize.Resize(uint(wanted_width), uint(wanted_height), img, resize.NearestNeighbor)
    bounds := resized.Bounds()
    dx, dy := bounds.Dx(), bounds.Dy()

    if wanted_type == tflite.UInt8 {
        bb := make([]byte, dx*dy*wanted_channels)
        for y := 0; y < dy; y++ {
            for x := 0; x < dx; x++ {
                col := resized.At(x, y)
                r, g, b, _ := col.RGBA()
                bb[(y*dx+x)*3+0= byte(float64(r) / 255.0)
                bb[(y*dx+x)*3+1= byte(float64(g) / 255.0)
                bb[(y*dx+x)*3+2= byte(float64(b) / 255.0)
            }
        }
        input.CopyFromBuffer(bb)
    } else {
        log.Fatal("is not wanted type")
    }

    status = interpreter.Invoke()
    if status != tflite.OK {
        log.Fatal("invoke failed")
    }

    output := interpreter.GetOutputTensor(0)
    output_size := output.Dim(output.NumDims() - 1)
    b := make([]byte, output_size)
    type result struct {
        score float64
        index int
    }
    status = output.CopyToBuffer(b)
    if status != tflite.OK {
        log.Fatal("output failed")
    }
    results := []result{}
    for i := 0; i < output_size; i++ {
        score := float64(b[i]) / 255.0
        if score < 0.2 {
            continue
        }
        results = append(results, result{score: score, index: i})
    }
    sort.Slice(results, func(i, j intbool {
        return results[i].score > results[j].score
    })
    for i := 0; i < len(results); i++ {
        fmt.Printf("%02d%s%f\n", results[i].index, labels[results[i].index], results[i].score)
        if i > 5 {
            break
        }
    }
}

label_image でおなじみの grace_hopper も正しく動きます。

grace_hopper

一応、TensorFlow Lite の C API は全て移植したつもりですが、何かバグっていたら教えて下さい。

Posted at by | Edit