Island Babies

いろんなことを書きます

関数をキャッシュする : lru_cache

Decoratorについて

Decoratorは、大雑把に言ってしまうと高階関数をシンプルに書くための記法です。Pythonでは関数もまた第1級関数であるため、関数を引数にとったり、関数を返値として返すことができます。関数を引数にとるような関数のことを高階関数と呼びます。有名なところでは、mapfilter高階関数の例として挙げられます。

list(map(fact, range(6))) # => [1, 1, 2, 6, 24, 120]
list(filter(lambda x: x < 0, range(-5, 5))) #=> [-5, -4, -3, -2, -1]

pythonではmapやfilterの返値はイテレータなので、例えばlistなどで受けてあげることになります。ちなみに、通常このような処理は通常内包表記で書けるので、できればmapやfilterで書くべきではないですし、内包表記で書くように設計されています。 さて、自分で高階関数を書くことを考えます。例えば、再帰的に関数を呼び出すような関数に対して処理を見やすくするための関数を考えてみましょう。関数を実行するときに、関数の名前や引数を表示してくれるような次のような関数を書くことを考えます。

def track(func):                                                                                                                                                        
    def tracked(*args, **kwargs):                                                                                                                                       
        result = func(*args, **kwargs)                                                                                                                                  
        name = func.__name__                                                                                                                                            
        arg_str = ",".join(repr(arg) for arg in args)                                                                                                                   
        print("%s(%s) -> %r" % (name, arg_str, result))                                                                                                                 
        return result                                                                                                                                                   
                                                                                                                                                                                  
    return tracked  

この関数はある関数を受け取って、その関数を引数に適用させ、関数の名前、引数、返値を表示したあとでその適用結果を返すような関数になっています。例えば、フィボナッチ数列に適用させるとき、次のようにかけます。

@track
def fibonacci_naive(n):
    if n< 2: 
        return n
    return fibonacci_naive(n-2) + fibonacci_naive(n-1)

結果は次のように出力されます。

fibonacci_naive(0) -> 0                                                                                                                                             
fibonacci_naive(1) -> 1                                                                                                                                                                
fibonacci_naive(2) -> 1                                                                                                                                             
fibonacci_naive(1) -> 1                                                                                                                                             
fibonacci_naive(0) -> 0                                                                                                                                             
fibonacci_naive(1) -> 1                                                                                                                                                                    
fibonacci_naive(2) -> 1                                                                                                                                             
fibonacci_naive(3) -> 2                                                                                                                                             
fibonacci_naive(4) -> 3                                                                                                                                             
fibonacci_naive(1) -> 1                                                                                                                 
fibonacci_naive(0) -> 0                                                                                                                 
fibonacci_naive(1) -> 1                                                                                                                 
fibonacci_naive(2) -> 1                                                                                                                 
fibonacci_naive(3) -> 2                                                                                                                 
fibonacci_naive(0) -> 0                                                                                                                 
fibonacci_naive(1) -> 1                                                                                                                 
fibonacci_naive(2) -> 1                                                                                                                 
fibonacci_naive(1) -> 1                                                                                                                 
fibonacci_naive(0) -> 0                                                                                                                 
fibonacci_naive(1) -> 1                                                                                                                 
fibonacci_naive(2) -> 1                                                                                                                 
fibonacci_naive(3) -> 2                                                                                                                 
fibonacci_naive(4) -> 3                                                                                                                 
fibonacci_naive(5) -> 5                                                                                                                 
fibonacci_naive(6) -> 8   

うまく表示されましたね。ここの@trackがデコレータと呼ばれる処理になります。真面目にこの処理を書こうと思ったら、

def fibonacci_naive(n):                                                                                                                                             
    if n < 2:                                                                                                                                                       
        return n                                                                                                                                                    
    return fibonacci_naive(n-2) + fibonacci_naive(n-1)  

fibonacci_naive = track(fibonacci_naive)

こーんな書き方をしないといけません。とってもダサいですね。と、いうわけで、まず単純に関数を引数とするような処理が書きやすいというのがあります。もう一つ良い理由として、二つ以上の関数を関数に適用させるようなときにもスマートにかけます。これについてはもうすぐ出てきます。

このような書き方を導入しているのだから当然ライブラリとしていくつかの便利なデコレータが提供されていそうだなーと思っていると便利なものがありました。

functools.lru_cache()

この関数は、大雑把に言ってしまうとメモ化をしてくれるようなデコレータになります。公式ドキュメントの説明では、

Decorator to wrap a function with a memoizing callable that saves up to the maxsize most recent calls. It can save time when an expensive or I/O bound function is periodically called with the same arguments.

という風になっています。最近呼ばれた関数の値をメモ化し、キャッシュしておくデコレータになります。例えば次のように使います.

import functools

@functools.lru_cache()                                                                                                                                              
@track                                                                                                                                                              
def fibonacci_memonized(n):                                                                                                                                         
    if n < 2:                                                                                                                                                       
        return n                                                                                                                                                    
    return fibonacci_memonized(n-2) + fibonacci_memonized(n-1)  

デコレータがあるおかげでスッキリかけています!本来なら二回関数に包むように書かないといけないですからね、だいぶスマートにかけているのではないでしょうか。比較のためにlru_cacheを使わないようなフィボナッチ関数と比較してみます。次のようなコードを実行してみましょう。

import functools                                                                                                                                                    
                                                                                                                                                                                      
def track(func):                                                                                                                                                        
    def tracked(*args, **kwargs):                                                                                                                                         
        result = func(*args, **kwargs)                                                                                                                                         
        name = func.__name__                                                                                                                                            
        arg_str = ",".join(repr(arg) for arg in args)                                                                                                                                         
        print("%s(%s) -> %r" % (name, arg_str, result))                                                                                                                                         
        return result                                                                                                                                                   
                                                                                                                                                                                  
    return tracked                                                                                                                                                   
                                                                                                                                                                                  
@functools.lru_cache()                                                                                                                                              
@track                                                                                                                                                              
def fibonacci_memonized(n):                                                                                                                                         
    if n < 2:                                                                                                                                                       
        return n                                                                                                                                                    
    return fibonacci_memonized(n-2) + fibonacci_memonized(n-1)                                                                                                                                         
                                                                                                                                                                    
@track                                                                                                                                                              
def fibonacci_naive(n):                                                                                                                                             
    if n < 2:                                                                                                                                                       
        return n                                                                                                                                                    
    return fibonacci_naive(n-2) + fibonacci_naive(n-1)                                                                                                                                         
                                                                                                                                                                    
print("-"*10, "naive-fibonacci", "-"*10, "\n")                                                                                                                                         
fibonacci_naive(6)                                                                                                                                                  
print("\n")                                                                                                                                                         
                                                                                                                                                                    
print("-"*10, "memonized-fibonacci", "-"*10, "\n")                                                                                                                                         
fibonacci_memonized(6)

これらを実行すると、次のように表示されます。

---------- naive-fibonacci ----------                                                                                                                               
                                                                                                                                                                                               
fibonacci_naive(0) -> 0                                                                                                                                             
fibonacci_naive(1) -> 1                                                                                                                                                                
fibonacci_naive(2) -> 1                                                                                                                                             
fibonacci_naive(1) -> 1                                                                                                                                             
fibonacci_naive(0) -> 0                                                                                                                                             
fibonacci_naive(1) -> 1                                                                                                                                                                    
fibonacci_naive(2) -> 1                                                                                                                                             
fibonacci_naive(3) -> 2                                                                                                                                             
fibonacci_naive(4) -> 3                                                                                                                                             
fibonacci_naive(1) -> 1
fibonacci_naive(0) -> 0
fibonacci_naive(1) -> 1
fibonacci_naive(2) -> 1
fibonacci_naive(3) -> 2
fibonacci_naive(0) -> 0
fibonacci_naive(1) -> 1
fibonacci_naive(2) -> 1
fibonacci_naive(1) -> 1
fibonacci_naive(0) -> 0
fibonacci_naive(1) -> 1
fibonacci_naive(2) -> 1
fibonacci_naive(3) -> 2
fibonacci_naive(4) -> 3
fibonacci_naive(5) -> 5
fibonacci_naive(6) -> 8


---------- memonized-fibonacci ---------- 

fibonacci_memonized(0) -> 0
fibonacci_memonized(1) -> 1
fibonacci_memonized(2) -> 1
fibonacci_memonized(3) -> 2
fibonacci_memonized(4) -> 3
fibonacci_memonized(5) -> 5
fibonacci_memonized(6) -> 8

キャッシュがいい感じにされていて呼び出し回数が最小限になっていますね。ちなみにデフォルトだと128個まで保持するようです。引数を指定すれば無制限にキャッシュを保持するようにもできますが、下手な処理でそんなことをするとデータが溢れかえりそうなので気をつけないといけないですね。

おわり

こういうメモ化は結構使う処理なのでこれだけシンプルにかけるととても便利ですね。functoolsには他にも関数を部分適用する関数とかもあって面白そうなので1度目を通してみたいですね。なお、内容としてはfluent pythonを読んでいてこのあたりの話を読んでまとめておこうと思いまとめたものになります。

備考

この記事は(https://qiita.com/Akutagawa/items/dbb035f117c764409281)をそのままコピーして少しいじったものになっていますが、本人です。

参考

Fluent Python ―Pythonicな思考とコーディング手法

公式ドキュメント(functools)