読者です 読者をやめる 読者になる 読者になる

白猫のメモ帳

JavaとかJavaScriptとかHTMLとか機械学習とか。

k近傍法をJavaで実装する

Java 機械学習

こんばんは。

暑いですね。
とろけてしまいそうです。

さて、前回は最近傍法を試してみましたが、今回はk近傍法を実装してみましょう。

別のクラスにしてもいいのですが、
最近傍法はk=1ですので、共通化しておきましょう。

kはコンストラクタ引数で渡すことにします。

public class NearestNeighbor implements LearningMachine {
    
    /** 最近傍採用数 */
    private final int k;
    
    /** 認識したパターン */
    private final List<Pair<Integer, double[]>> learning = new ArrayList<>();

    /**
     * コンストラクタ
     */
    public NearestNeighbor(int k) {
        this.k = k;
    }
    
    /**
     * 教師データを追加
     */
    @Override
    public void add(int label, double[] feature) {
        this.learning.add(new Pair(label, feature));
    }
    
    /** 
     * 学習 
     */
    @Override
    public void learn() {
    }

    /**
     * 描画
     */
    @Override
    public void draw(GraphicsContext gc) {
        
        int w = (int) gc.getCanvas().getWidth();
        int h = (int) gc.getCanvas().getHeight();
        
        for (int x = 0; x < w; x += 2) {
            for (int y = 0; y < h; y += 2) {
                int ans = this.trial(new double[]{x, y});
                if (ans == 0) {
                    continue;
                }
                gc.setFill(ans > 0 ? Color.BLUE : Color.RED);
                gc.fillOval(x, y, 1, 1);
            }
        }
    }
    
    /**
     * 評価
     */
    @Override
    public int predict(double[] feature) {
        
        // 距離とクラスのランキング
        Map<Double, Integer> ranking = new TreeMap<>();
        
        // 一番近いパターンを求める
        this.learning.stream().filter(entry -> entry.getValue().length == feature.length)
                              .forEach((entry) -> {
            
            double[] pos = entry.getValue();
            
            // データ間の距離を求める(平方ユークリッド距離)
            double dist = Stream.iterate(0, i -> i + 1).limit(pos.length)
                                                       .mapToDouble(i -> Math.pow(pos[i] - feature[i], 2))
                                                       .sum();

            // ソートして貯める(ベクトルは特に必要ないので距離とクラス)
            ranking.put(dist, entry.getKey());
        });
        
        // クラスと個数でマッピング
        Map<Integer, Long> map = ranking.entrySet().stream().limit(k)
                                                            .collect(Collectors.groupingBy(x -> x.getValue(), Collectors.counting()));
        
        // 最大のクラスを取得
        return map.entrySet().stream().filter(e -> e.getValue().longValue() == Collections.max(map.values()).longValue())
                                      .mapToInt(e -> e.getKey())
                                      .findFirst().orElse(0);
    }
    
    /**
     * リセット
     */
    @Override
    public void reset() {
        this.learning.clear();
    }

    /**
     * タイトル
     */
    @Override
    public String getTitle() {
        return (k == 1 ? "最" : k) + "近傍法";
    }
}

単純に距離でソートして近い順に指定件数を取り出し、多数決を取るだけですね。

    /** 学習機 */
    private final LearningMachine lm = new NearestNeighbor(3);

コントローラはこんな感じになります。
画面上で変更できてもいいかなとも思ったのですが、インタフェースが変になりそうだったのやめました。


では試してみましょう。
最近傍法ではノイズデータに引っ張られてしまうのに対して、

f:id:Shiro-Neko:20160704202105j:plain

3近傍法では多数決によってノイズデータがうまい具合に無視されるようになりました。

f:id:Shiro-Neko:20160704202149j:plain

じゃあkの値はどんどん大きくすれば良くなるよね。
と思いきやそうでもないのです。

f:id:Shiro-Neko:20160704202259j:plain

母数に差が大きいと、多数派の浸食が…。


ところで、「k-nearest neighbor algorithm」なのにk近傍法でいいのでしょうか。
k最近傍法が正しいのでしょうか。はて…。