なんとなく SQLite3 でロジスティック回帰できたら面白そうと思ったので作ってみた。
データセットは iris、sqlflow の DDL を使わせて頂いた。
sqlflow/example/datasets at develop · sql-machine-learning/sqlflow - GitHub
It should print the number of rows as the following: count(*) 10 Troubleshooting It usually takes ab...
https://github.com/sql-machine-learning/sqlflow/tree/develop/example/datasets
CREATE DATABASE IF NOT EXISTS iris;
DROP TABLE IF EXISTS iris.train;
CREATE TABLE iris.train (
sepal_length float,
sepal_width float,
petal_length float,
petal_width float,
class int);
INSERT INTO iris.train VALUES(6.4,2.8,5.6,2.2,2);
INSERT INTO iris.train VALUES(5.0,2.3,3.3,1.0,1);
INSERT INTO iris.train VALUES(4.9,2.5,4.5,1.7,2);
INSERT INTO iris.train VALUES(4.9,3.1,1.5,0.1,0);
INSERT INTO iris.train VALUES(5.7,3.8,1.7,0.3,0);
INSERT INTO iris.train VALUES(4.4,3.2,1.3,0.2,0);
INSERT INTO iris.train VALUES(5.4,3.4,1.5,0.4,0);
INSERT INTO iris.train VALUES(6.9,3.1,5.1,2.3,2);
INSERT INTO iris.train VALUES(6.7,3.1,4.4,1.4,1);
INSERT INTO iris.train VALUES(5.1,3.7,1.5,0.4,0);
INSERT INTO iris.train VALUES(5.2,2.7,3.9,1.4,1);
INSERT INTO iris.train VALUES(6.9,3.1,4.9,1.5,1);
INSERT INTO iris.train VALUES(5.8,4.0,1.2,0.2,0);
INSERT INTO iris.train VALUES(5.4,3.9,1.7,0.4,0);
INSERT INTO iris.train VALUES(7.7,3.8,6.7,2.2,2);
INSERT INTO iris.train VALUES(6.3,3.3,4.7,1.6,1);
INSERT INTO iris.train VALUES(6.8,3.2,5.9,2.3,2);
INSERT INTO iris.train VALUES(7.6,3.0,6.6,2.1,2);
INSERT INTO iris.train VALUES(6.4,3.2,5.3,2.3,2);
INSERT INTO iris.train VALUES(5.7,4.4,1.5,0.4,0);
INSERT INTO iris.train VALUES(6.7,3.3,5.7,2.1,2);
INSERT INTO iris.train VALUES(6.4,2.8,5.6,2.1,2);
INSERT INTO iris.train VALUES(5.4,3.9,1.3,0.4,0);
INSERT INTO iris.train VALUES(6.1,2.6,5.6,1.4,2);
INSERT INTO iris.train VALUES(7.2,3.0,5.8,1.6,2);
INSERT INTO iris.train VALUES(5.2,3.5,1.5,0.2,0);
INSERT INTO iris.train VALUES(5.8,2.6,4.0,1.2,1);
INSERT INTO iris.train VALUES(5.9,3.0,5.1,1.8,2);
INSERT INTO iris.train VALUES(5.4,3.0,4.5,1.5,1);
INSERT INTO iris.train VALUES(6.7,3.0,5.0,1.7,1);
INSERT INTO iris.train VALUES(6.3,2.3,4.4,1.3,1);
INSERT INTO iris.train VALUES(5.1,2.5,3.0,1.1,1);
INSERT INTO iris.train VALUES(6.4,3.2,4.5,1.5,1);
INSERT INTO iris.train VALUES(6.8,3.0,5.5,2.1,2);
INSERT INTO iris.train VALUES(6.2,2.8,4.8,1.8,2);
INSERT INTO iris.train VALUES(6.9,3.2,5.7,2.3,2);
INSERT INTO iris.train VALUES(6.5,3.2,5.1,2.0,2);
INSERT INTO iris.train VALUES(5.8,2.8,5.1,2.4,2);
INSERT INTO iris.train VALUES(5.1,3.8,1.5,0.3,0);
INSERT INTO iris.train VALUES(4.8,3.0,1.4,0.3,0);
INSERT INTO iris.train VALUES(7.9,3.8,6.4,2.0,2);
INSERT INTO iris.train VALUES(5.8,2.7,5.1,1.9,2);
INSERT INTO iris.train VALUES(6.7,3.0,5.2,2.3,2);
INSERT INTO iris.train VALUES(5.1,3.8,1.9,0.4,0);
INSERT INTO iris.train VALUES(4.7,3.2,1.6,0.2,0);
INSERT INTO iris.train VALUES(6.0,2.2,5.0,1.5,2);
INSERT INTO iris.train VALUES(4.8,3.4,1.6,0.2,0);
INSERT INTO iris.train VALUES(7.7,2.6,6.9,2.3,2);
INSERT INTO iris.train VALUES(4.6,3.6,1.0,0.2,0);
INSERT INTO iris.train VALUES(7.2,3.2,6.0,1.8,2);
INSERT INTO iris.train VALUES(5.0,3.3,1.4,0.2,0);
INSERT INTO iris.train VALUES(6.6,3.0,4.4,1.4,1);
INSERT INTO iris.train VALUES(6.1,2.8,4.0,1.3,1);
INSERT INTO iris.train VALUES(5.0,3.2,1.2,0.2,0);
INSERT INTO iris.train VALUES(7.0,3.2,4.7,1.4,1);
INSERT INTO iris.train VALUES(6.0,3.0,4.8,1.8,2);
INSERT INTO iris.train VALUES(7.4,2.8,6.1,1.9,2);
INSERT INTO iris.train VALUES(5.8,2.7,5.1,1.9,2);
INSERT INTO iris.train VALUES(6.2,3.4,5.4,2.3,2);
INSERT INTO iris.train VALUES(5.0,2.0,3.5,1.0,1);
INSERT INTO iris.train VALUES(5.6,2.5,3.9,1.1,1);
INSERT INTO iris.train VALUES(6.7,3.1,5.6,2.4,2);
INSERT INTO iris.train VALUES(6.3,2.5,5.0,1.9,2);
INSERT INTO iris.train VALUES(6.4,3.1,5.5,1.8,2);
INSERT INTO iris.train VALUES(6.2,2.2,4.5,1.5,1);
INSERT INTO iris.train VALUES(7.3,2.9,6.3,1.8,2);
INSERT INTO iris.train VALUES(4.4,3.0,1.3,0.2,0);
INSERT INTO iris.train VALUES(7.2,3.6,6.1,2.5,2);
INSERT INTO iris.train VALUES(6.5,3.0,5.5,1.8,2);
INSERT INTO iris.train VALUES(5.0,3.4,1.5,0.2,0);
INSERT INTO iris.train VALUES(4.7,3.2,1.3,0.2,0);
INSERT INTO iris.train VALUES(6.6,2.9,4.6,1.3,1);
INSERT INTO iris.train VALUES(5.5,3.5,1.3,0.2,0);
INSERT INTO iris.train VALUES(7.7,3.0,6.1,2.3,2);
INSERT INTO iris.train VALUES(6.1,3.0,4.9,1.8,2);
INSERT INTO iris.train VALUES(4.9,3.1,1.5,0.1,0);
INSERT INTO iris.train VALUES(5.5,2.4,3.8,1.1,1);
INSERT INTO iris.train VALUES(5.7,2.9,4.2,1.3,1);
INSERT INTO iris.train VALUES(6.0,2.9,4.5,1.5,1);
INSERT INTO iris.train VALUES(6.4,2.7,5.3,1.9,2);
INSERT INTO iris.train VALUES(5.4,3.7,1.5,0.2,0);
INSERT INTO iris.train VALUES(6.1,2.9,4.7,1.4,1);
INSERT INTO iris.train VALUES(6.5,2.8,4.6,1.5,1);
INSERT INTO iris.train VALUES(5.6,2.7,4.2,1.3,1);
INSERT INTO iris.train VALUES(6.3,3.4,5.6,2.4,2);
INSERT INTO iris.train VALUES(4.9,3.1,1.5,0.1,0);
INSERT INTO iris.train VALUES(6.8,2.8,4.8,1.4,1);
INSERT INTO iris.train VALUES(5.7,2.8,4.5,1.3,1);
INSERT INTO iris.train VALUES(6.0,2.7,5.1,1.6,1);
INSERT INTO iris.train VALUES(5.0,3.5,1.3,0.3,0);
INSERT INTO iris.train VALUES(6.5,3.0,5.2,2.0,2);
INSERT INTO iris.train VALUES(6.1,2.8,4.7,1.2,1);
INSERT INTO iris.train VALUES(5.1,3.5,1.4,0.3,0);
INSERT INTO iris.train VALUES(4.6,3.1,1.5,0.2,0);
INSERT INTO iris.train VALUES(6.5,3.0,5.8,2.2,2);
INSERT INTO iris.train VALUES(4.6,3.4,1.4,0.3,0);
INSERT INTO iris.train VALUES(4.6,3.2,1.4,0.2,0);
INSERT INTO iris.train VALUES(7.7,2.8,6.7,2.0,2);
INSERT INTO iris.train VALUES(5.9,3.2,4.8,1.8,1);
INSERT INTO iris.train VALUES(5.1,3.8,1.6,0.2,0);
INSERT INTO iris.train VALUES(4.9,3.0,1.4,0.2,0);
INSERT INTO iris.train VALUES(4.9,2.4,3.3,1.0,1);
INSERT INTO iris.train VALUES(4.5,2.3,1.3,0.3,0);
INSERT INTO iris.train VALUES(5.8,2.7,4.1,1.0,1);
INSERT INTO iris.train VALUES(5.0,3.4,1.6,0.4,0);
INSERT INTO iris.train VALUES(5.2,3.4,1.4,0.2,0);
INSERT INTO iris.train VALUES(5.3,3.7,1.5,0.2,0);
INSERT INTO iris.train VALUES(5.0,3.6,1.4,0.2,0);
INSERT INTO iris.train VALUES(5.6,2.9,3.6,1.3,1);
INSERT INTO iris.train VALUES(4.8,3.1,1.6,0.2,0);
DROP TABLE IF EXISTS iris.test;
CREATE TABLE iris.test (
sepal_length float,
sepal_width float,
petal_length float,
petal_width float,
class int);
INSERT INTO iris.test VALUES(6.3,2.7,4.9,1.8,2);
INSERT INTO iris.test VALUES(5.7,2.8,4.1,1.3,1);
INSERT INTO iris.test VALUES(5.0,3.0,1.6,0.2,0);
INSERT INTO iris.test VALUES(6.3,3.3,6.0,2.5,2);
INSERT INTO iris.test VALUES(5.0,3.5,1.6,0.6,0);
INSERT INTO iris.test VALUES(5.5,2.6,4.4,1.2,1);
INSERT INTO iris.test VALUES(5.7,3.0,4.2,1.2,1);
INSERT INTO iris.test VALUES(4.4,2.9,1.4,0.2,0);
INSERT INTO iris.test VALUES(4.8,3.0,1.4,0.1,0);
INSERT INTO iris.test VALUES(5.5,2.4,3.7,1.0,1);
僕が作ってる Go の SQLite3 ドライバはユーザ関数を Go で書く事が出来る。
sql.Register("sqlite3_custom", &sqlite3.SQLiteDriver{
ConnectHook: func(conn *sqlite3.SQLiteConn) error {
if err := conn.RegisterAggregator("logistic_regression_train", createLogisticRegressionTrain(conn), true); err != nil {
return err
}
if err := conn.RegisterFunc("logistic_regression_predict", createLogisticRegressionPredict(conn), true); err != nil {
return err
}
return nil
},
})
db, err := sql.Open("sqlite3_custom", ":memory:")
if err != nil {
log.Fatal(err)
}
defer db.Close()
_, err = db.Exec(`attach "iris.sqlite" as iris`)
if err != nil {
log.Fatal(err)
}
ユーザ関数とアグリゲート関数は動作が異なっていて、ユーザ関数は SELECT で使うと行毎に呼び出され、行毎の結果が返る。アグリゲート関数は行毎に Step メソッドが呼ばれ、最後に Done メソッドが呼ばれる。つまり集計関数になる。アグリゲート関数で以下の様に SELECT した結果を全て貰いモデルを作る。モデルは JSON 形式で出力する様にした。文字列を持ったテーブルにそのまま突っ込める。これを logistic_regression_train という関数名にした。
_, err = db.Exec(`
drop table if exists iris.model;
create table iris.model(config text);
insert into iris.model
select
logistic_regression_train('{
"rate": 0.1,
"ntrains": 5000
}',
sepal_length,
sepal_width,
petal_length,
petal_width,
class
)
from
iris.train
`)
if err != nil {
log.Fatal(err)
}
次にこの JSON からモデルに戻し、引数で渡されたテストデータから推論する関数 logistics_regression_predict を作った。
rows, err := db.Query(`
select
logistic_regression_predict('iris.model',
sepal_length,
sepal_width,
petal_length,
petal_width
), class
from
iris.test
`)
if err != nil {
log.Fatal(err)
}
defer rows.Close()
for rows.Next() {
var predicted, class float64
err = rows.Scan(&predicted, &class)
if err != nil {
log.Fatal(err)
}
fmt.Println(math.RoundToEven(predicted), class)
}
ロジスティック回帰そのものは gonum を使って書いた。
func (s *logistic_regression) Done() (string, error) {
ws := make([]float64, s.X[0].Len())
for i := range ws {
ws[i] = s.rand.Float64()
}
for i := range s.y {
s.y[i] = s.y[i] / (s.maxy + 1)
}
w := mat.NewVecDense(len(ws), ws)
y := mat.NewVecDense(len(s.y), s.y)
for n := 0; n < s.cfg.NTrains; n++ {
for i, x := range s.X {
t := mat.NewVecDense(x.Len(), nil)
t.CopyVec(x)
pred := softmax(t, w)
perr := y.AtVec(i) - pred
scale := s.cfg.Rate * perr * pred * (1 - pred)
for j := 0; j < x.Len(); j++ {
dx := mat.NewVecDense(x.Len(), nil)
dx.CopyVec(x)
dx.ScaleVec(scale, x)
w.AddVec(w, dx)
}
}
}
fargs := make([]float64, w.Len())
for i := 0; i < w.Len(); i++ {
fargs[i] = w.AtVec(i)
}
var buf bytes.Buffer
err := json.NewEncoder(&buf).Encode(&model{
W: fargs,
M: s.maxy,
})
if err != nil {
return "", err
}
return buf.String(), nil
}
この例では推論した値 predict と、正解の値 class が SELECT されるので Go で値を取り出すと推論が正しいか判断できる。
for rows.Next() {
var predicted, class float64
err = rows.Scan(&predicted, &class)
if err != nil {
log.Fatal(err)
}
fmt.Printf(
"predict: %d (%d)\n",
int(math.RoundToEven(predicted)), int(class))
}
predict: 1 (2)
predict: 1 (1)
predict: 0 (0)
predict: 2 (2)
predict: 0 (0)
predict: 1 (1)
predict: 1 (1)
predict: 0 (0)
predict: 0 (0)
predict: 1 (1)
正解率 90% なのでまずまずと言っていいのかな。
サンプルコードの位置づけだけど GitHub にコードを置いておきます。
GitHub - mattn/go-sqlite3-logistics-regression
Features → Code review Project management Integrations Actions Package registry Team management...
https://github.com/mattn/go-sqlite3-logistics-regression