白猫のメモ帳

JavaとかJavaScriptとかHTMLとか機械学習とか。

ラムダ式をメモ化する部品を作ろうとして混乱したお話

こんばんは。

クリスマス直前とは思えない暖かさですね。
大掃除は進んでいますでしょうか。

メモ化とフィボナッチ数列


さて、今回はメモ化のお話です。

メモ化とは、簡単に言えば、参照透過性を持つ関数の計算量を削減するためのキャッシュです。
引数に対する戻り値が一定の場合、それを覚えておいて、次回は計算せずに答えを返しましょうということですね。

メモ化のサンプルとしてよく出てくるのはフィボナッチ数列です。
これは最初の二項が0, 1であり、以後は直前の2つの項の和になる数列です。
具体的にはこんな感じですね。

0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377...

 

普通にフィボナッチ


というわけで、まずはメモ化せずにフィボナッチ数を求める関数を作ってみます。
負数はチェックした方がいいのかもしれませんが、今回は気にしないということで。
いわゆる「再帰」を使うと、シンプルに書くことができますね。

private static long fibonacci(int n) {
    return n <= 1 ? n : fibonacci(n - 2) + fibonacci(n - 1);
}

こんな感じで実行してみると、80秒くらいかかりました。

IntStream.range(0, 50)
         .mapToLong(n -> fibonacci(n))
         .forEach(System.out::println);

 

メモ化する


ではこれをメモ化してみます。こうかな。

private static Map<Integer, Long> map = new HashMap<>();
private static long memoizationFibonacci(int n) {
    return map.computeIfAbsent(n, i -> i <= 1 ? i : memoizationFibonacci(i - 2) + memoizationFibonacci(i - 1));
}

実行してみると、1秒未満になりました。やったね。

ラムダ式再帰する


せっかくなのでどんな関数でも同じことができるように、
関数を受け取るようにしてみたいと思います。

が、フィボナッチ関数には再帰があります。ラムダ式再帰する場合、どうすればいいかちょっと悩みますよね。
いったんメモ化機能は外して考えてみましょう。

// コンパイルエラー(関数の定義時にその関数自体は参照できない)
Function<Integer, Long> func = n -> n <= 1 ? n : func.apply(n - 2) + func.apply(n - 1));

というわけで、こうしてみます。

@FunctionalInterface
interface RecursiveFunction<T, R> {
    R apply(T t, RecursiveFunction<T, R> func);
}

RecursiveFunction<Integer, Long> func = 
        (n, f) -> n <= 1 ? n : (f.apply(n - 2, f) + f.apply(n - 1, f));

applyするタイミングで自身を引数で受けると参照できるよねということです。
呼び出し方はこうです。

IntStream.range(0, 50)
         .mapToLong(n -> func.apply(n, func))
         .forEach(System.out::println);

とりあえずラムダの再帰ができるようになりました。

メモ化を外からくっつける


さて、関数をメモ化したいので、
値を記録しておくためのMapと関数とフィールドに持つクラスを作ってみます。

class Memoizer<T, R> {
        
    private final Map<T, R> map = new HashMap<>();

    private final RecursiveFunction<T, R> func;

    Memoizer(RecursiveFunction<T, R> func) {
        this.func = func;
    }
}


次に、Tを受けてRを返す再帰型の関数を作る必要があるので、
こんなシグネチャのメソッドを作ればよいでしょうか。

R apply(T t, RecursiveFunction<T, R> func);

呼び出しはこうなりますね。

Memoizer<Integer, Long> memoizer = new Memoizer<>(func);
IntStream.range(0, 50)
         .mapToLong(n -> memoizer.apply(n, func))
         .forEach(System.out::println);


発想としてはmemoizationFibonacciと同じなので、Map#computeIfAbsentを使えばよさそうですが、
問題はcomputeIfAbsentの第二引数のFunctionです。

とりあえず第二引数にRecursiveFunctionをもらっているので、これを使ってみます。

R apply(T t, RecursiveFunction<T, R> func) {
    return this.map.computeIfAbsent(t, tt -> func.apply(tt, func));
}

うまくいきません・・・。
再帰したときにオリジナルの関数を参照してしまうので、深さ1しかメモ化されていないようです。
再帰時に呼び出す関数を自身に変更してみましょう。

R apply(T t, RecursiveFunction<T, R> func) {
    return this.map.computeIfAbsent(t, tt -> func.apply(tt, this::apply));
}

なんだかうまくいったように見えますが、range(0, 50)をrange(1, 50)にしたらスタックオーバーフローしました。
第二引数のRecursiveFunctionが自身の関数を参照するようになったので、同じところをぐるぐる回っているようです。

正解はこれです。

R apply(T t, RecursiveFunction<T, R> func) {
    return this.map.computeIfAbsent(t, tt -> this.func.apply(tt, this::apply));
}

第二引数のRecursiveFunctionは参照しません。フェイントだ!

騙されてしまったので第二引数は消してしまいましょう。

// コンパイルエラー!!
R apply(T t) {
    return this.map.computeIfAbsent(t, tt -> this.func.apply(tt, this::apply));
}

シグネチャが変わってしまって、自身のメソッド参照が上手くいきません。フェイントだ!

完成


というわけで、最終的にはこうなりました。
implementsする必要はあるようなないような・・・?

public class Memo {
    
    public static void main(String[] args) throws Exception {
        
        RecursiveFunction<Integer, Long> func = 
            (n, f) -> n <= 1 ? n : (f.apply(n - 2, f) + f.apply(n - 1, f));
        
        Memoizer<Integer, Long> memoizer = new Memoizer<>(func);
        IntStream.range(0, 50)
                 .mapToLong(n -> memoizer.apply(n, func))
                 .forEach(System.out::println);
    }
}

@FunctionalInterface
interface RecursiveFunction<T, R> {
    R apply(T t, RecursiveFunction<T, R> func);
}

class Memoizer<T, R> implements RecursiveFunction<T, R> {
        
    private final Map<T, R> map = new HashMap<>();

    private final RecursiveFunction<T, R> func;

    Memoizer(RecursiveFunction<T, R> func) {
        this.func = func;
    }

    @Override
    public R apply(T t, RecursiveFunction<T, R> func) {
        return this.map.computeIfAbsent(t, tt -> this.func.apply(tt, this::apply));
    }
}

関数型が絡むユーティリティはなかなか難しいですね。
でも、書いていて楽しいです。

それではまた。