白猫のメモ帳

C#とかJavaとかJavaScriptとかHTMLとか機械学習とか。

最近傍法をJavaで試してみる

こんばんは。

最近はちょっと機械学習を勉強している私です。

機械が学習できれば私は学習しなくていいんですよね。
すばらしいです。
そのために私は勉強します。・・・ん?

さておき、とりあえず最初は機械学習アルゴリズムの中でももっとも単純な(ってWikipediaさんがいっている)
k近傍法、そのうちの更に一番簡単なk=1の最近傍法を試してみます。

k近傍法(ケイきんぼうほう、英: k-nearest neighbor algorithm, k-NN)は、
特徴空間における最も近い訓練例に基づいた分類の手法であり、パターン認識でよく使われる。最近傍探索問題の一つ。
k近傍法は、インスタンスに基づく学習の一種であり、怠惰学習 (lazy learning) の一種である。
その関数は局所的な近似に過ぎず、全ての計算は分類時まで後回しにされる。また、回帰分析にも使われる。

えーと…まぁ…なんというか…
つまり、「分類の手法」です。(雑)

f:id:Shiro-Neko:20160630205222j:plain

こういうのを

f:id:Shiro-Neko:20160630205251j:plain

こういう風にしたら分けられそうだよね。
みたいなことです。

f:id:Shiro-Neko:20160630205656j:plain

だからここはたぶん赤だよねみたいなことです。

さて、ではなんで「たぶん赤」みたいなことがいえるのでしょうか。
何となく赤に近い感じがするからですよね。

f:id:Shiro-Neko:20160630212253j:plain

これを何となくではなく、一番近い点に分類しようというのが最近傍法です。


実装はJavaFXで行います。

今回は最近傍法を利用しますが、学習方法を差し替えて使いまわしたいので、
まずは学習機のインタフェースを作ります。

public interface LearningMachine {
    
    /** 教師データを追加 */
    void add(int lavel, double[] feature);
    
    /** 学習 */
    void learn();
    
    /** 判定 */
    int predict(double[] feature);
    
    /** 描画 */
    void draw(GraphicsContext gc);
    
    /** リセット */
    void reset();
    
    /** タイトル(学習機の名前) */
    String getTitle();
}

どのように分類の線を引くかを描画したいためdrawという関数を作ってみました。
リセットでは教師データと学習結果をクリアします。
タイトルは今どの学習機がセットされているかわからなくなりそうだったのでつけました。

次にこのインタフェースを最近傍法のアルゴリズムで実装します。

public class NearestNeighbor implements LearningMachine {
       
    /** 認識したパターン */
    private final List<Pair<Integer, double[]>> learning = new ArrayList<>();
    
    /**
     * 教師データを追加
     */
    @Override
    public void add(int lavel, double[] feature) {
        this.learning.add(new Pair(lavel, feature));
    }
    
    /** 
     * 学習 
     */
    @Override
    public void learn() {
    }

    /**
     * 描画
     */
    @Override
    public void draw(GraphicsContext gc) {
        
        int w = (int) gc.getCanvas().getWidth();
        int h = (int) gc.getCanvas().getHeight();
        
        for (int x = 0; x < w; x += 2) {
            for (int y = 0; y < h; y += 2) {
                int ans = this.trial(new double[]{x, y});
                if (ans == 0) {
                    continue;
                }
                gc.setFill(ans > 0 ? Color.BLUE : Color.RED);
                gc.fillOval(x, y, 1, 1);
            }
        }
    }
    
    /**
     * 評価
     */
    @Override
    public int predict(double[] feature) {
        
        int clazz = 0;
        double min = Double.MAX_VALUE;

        for (Pair<Integer, double[]> entry : this.learning) {
            
            double[] pos = entry.getValue();
            
            // データ間の距離を求める(平方ユークリッド距離)
            double dist = Stream.iterate(0, i -> i + 1)
                                .limit(pos.length)
                                .mapToDouble(i -> Math.pow(pos[i] - feature[i], 2))
                                .sum();
            
            // 距離最小を探す
            if (dist < min) {
                min = dist;
                clazz = entry.getKey();
            }
        }
        
        return clazz;
    }
    
    /**
     * リセット
     */
    @Override
    public void reset() {
        this.learning.clear();
    }

    /**
     * タイトル
     */
    @Override
    public String getTitle() {
        return "最近傍法";
    }
}

一番最初に引用したWikipediaの説明にもある通り、k近傍法は怠惰学習のため学習メソッドは空っぽです。
評価メソッドの中で逐次計算していくため、教師データの母数が増えると恐ろしい勢いで遅くなります。

さて、学習機能はできましたが描画できないので、残りを適当に作ります。
JavaFXアプリって初めて作るので、変なことやっていたらごめんなさい。

まずはFXML。

<?xml version="1.0" encoding="UTF-8"?>

<?import java.lang.*?>
<?import java.net.URL ?>
<?import javafx.scene.canvas.*?>
<?import javafx.scene.text.*?>
<?import javafx.scene.control.*?>
<?import javafx.scene.layout.*?>
<?import javafx.geometry.*?>
<?import javafx.collections.*?>
 
 
<GridPane  xmlns="http://javafx.com/javafx/8" xmlns:fx="http://javafx.com/fxml/1" 
           fx:controller="fxpractice.PracticeController" hgap="10" vgap="10" styleClass="root" alignment="BASELINE_CENTER">
    
    <padding>
        <Insets top="10" right="10" bottom="10" left="10"/>
    </padding>

    <Label fx:id="title" alignment="CENTER" GridPane.columnIndex="0" GridPane.rowIndex="1" style="-fx-font-size: 16pt"/>
    
    <Canvas fx:id="canvas" width="400" height="400" GridPane.columnIndex="0" GridPane.rowIndex="2" onMouseClicked="#handleCanvasClick" />
    
    <fx:define>
        <ToggleGroup fx:id="toggle" />
    </fx:define>
    
    <HBox fx:id="button_area" spacing="10" alignment="CENTER" GridPane.columnIndex="0" GridPane.rowIndex="3">
        
        <padding>
            <Insets top="10" right="10" bottom="10" left="10"/>
        </padding>
        
        <RadioButton fx:id="radio1" text="青" toggleGroup="$toggle" userData="1" selected="true" />
        <RadioButton fx:id="radio2" text="赤" toggleGroup="$toggle" userData="-1" />
        <Button text="学習" prefWidth="80" onAction="#handleLearnButton" />
        <Button text="クリア" prefWidth="80" onAction="#handleClearButton" />

    </HBox>
    
</GridPane>

次にコントローラ。

public class PracticeController implements Initializable {
    
    /** タイトル */
    @FXML
    private Label title;
    
    /** キャンバス */
    @FXML 
    private Canvas canvas;
    
    /** チェックボックス */
    @FXML
    private ToggleGroup toggle;
    
    /** 学習機 */
    private final LearningMachine lm = new NearestNeighbor();
    
    /**
     * クリアボタン
     */
    @FXML
    protected void handleClearButton(ActionEvent event) {
        this.clear();
    }
    
    /**
     * 学習ボタン
     */
    @FXML
    protected void handleLearnButton(ActionEvent event) {
        
        GraphicsContext gc = canvas.getGraphicsContext2D();
        
        // 学習する
        this.lm.learn();

        // 描画する
        this.lm.draw(gc);
    }
    
    /**
     * キャンバスクリック
     */
    @FXML
    protected void handleCanvasClick(MouseEvent event) {
        
        // 赤と青のどちらを選んでいるか(教師ラベル)
        int val = Integer.parseInt((String)toggle.getSelectedToggle().getUserData());
        
        // 教師ラベルと教師データを設定
        this.lm.add(val, new double[]{event.getX(), event.getY()});
        
        // 画面に描画
        GraphicsContext gc = this.canvas.getGraphicsContext2D();
        gc.setFill(val > 0 ? Color.BLUE : Color.RED);
        gc.fillOval(event.getX(), event.getY(), 5, 5);
    }
    
    /**
     * 初期化
     */
    @Override
    public void initialize(URL url, ResourceBundle rb) {

        // クリア
        this.clear();
        
        // タイトルを設定
        this.title.setText(lm.getTitle());
    }
    
    /**
     * キャンバスをクリア
     */
    private void clear() {
        
        GraphicsContext gc = canvas.getGraphicsContext2D();
        
        // 全部消す
        gc.clearRect(0, 0, 400, 400);
        
        // 枠だけつくる
        gc.setFill(Color.WHITE);
        gc.setStroke(Color.GREEN);
        gc.fillRect(0, 0, 400, 400);
        gc.strokeRect(0, 0, 400, 400);
        
        // 学習機もリセットする
        this.lm.reset();
    }
}

で、メインクラス。

public class Practice extends Application {
    
    public static void main(String[] args) {
        launch(args);
    }
    
    @Override
    public void start(Stage primaryStage) throws IOException {
        
        Parent pane = FXMLLoader.load(getClass().getResource("practice.fxml"));
        
        Scene scene = new Scene(pane, 500, 520);
               
        primaryStage.setTitle("機械学習てすと!");
        primaryStage.setScene(scene);
        
        primaryStage.show();
    }
}

これで完成!
実行してみます。

f:id:Shiro-Neko:20160630233433j:plain

青と赤の点をそれぞれ適当に打ってから学習ボタンを押すと、学習の結果が表示されます。
学習と描画はインタフェース的には分けたかったけれど、
アプリケーション的には分けると面倒だったので、学習した瞬間にそれを描画しています。

f:id:Shiro-Neko:20160630233745j:plain

なんかいい感じに判定できているみたいですね。

f:id:Shiro-Neko:20160701193439j:plain

線形分離不可でも表現できます。

ずいぶん長くなってしまいましたが、今日はここまで。
次は3近傍法を試します。