Pythonのfor文は遅い?

bicycle1885.hatenablog.com

こちらの記事を拝見していて、ちょっと気になったので注釈。

PythonやRを使っている人で、ある程度重い計算をする人達には半ば常識になっていることとして、いわゆる「for文を使ってはいけない。ベクトル化*1しろ。」という助言があります。 これは、PythonやRのようなインタープリター方式の処理系をもつ言語では、極めてfor文が遅いため、C言語Fortranで実装されたベクトル化計算を使うほうが速いという意味です。

昔からよくこういう言い方がよくされるが、本当にPythonのfor文は遅いのだろうか。

聞くところによるとRのfor文はガチで遅いそうだが、Pythonの計算が遅いのはインタープリタ方式だからでも、for文が遅いからでもない。もちろん、Pythonインタープリタなので遅いし、for文だって極めて遅い。しかし、これはPythonの計算が遅い要因の一部でしかない。

まずは手元の環境(Macbook Air 2015, Python 3.6)で速度を測ってみよう。以下のコードはすべて Jupyter Notebookで実行している。

import numpy as np
a = np.ones(100000)
b=np.ones(100000 )

def dot(a, b):
    s = 0
    for i in range(len(a)):
        s += a[i] * b[i]
    return s

timeit dot(a, b)
41.6 ms ± 1.51 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

timeit np.dot(a, b)
62 µs ± 5.6 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

Pythonのループを使った演算と、numpyを使った演算ではパフォーマンスに大きな差がある。これは for文が遅いから なのだろうか?

試しに、演算をせずにforループだけを実行してみよう。

def loop(a, b):
    s = 0
    for i in range(len(a)):
        pass
    return s

timeit loop(a, b)
3.44 ms ± 129 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

演算を行わず、forループを実行するだけなら、全体の10%以下しかかかっていない。まあ、forループ遅いが、全体の遅さの主犯ではないようだ。

では、Pythonの遅さの残り9割はどこからくるのだろう?

ここからは、Cythonを使って原因を探っていこう。

まず、Cythonで dot()C言語に変換する。

%%cython
def dot_cython(a,b):
    s = 0
    for i in range(len(a)):
        s += a[i] * b[i]
    return s

timeit dot_cython(a,b)
21.9 ms ± 717 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

単純にC言語に変換しただけでは、それほど変化はない。

ここでは、dot_cython()for ループは、CythonによってC言語のループに展開されており、Pythonのようなループによるオーバーヘッドはなくなっている。

また、Pythonバイトコードを経由せずに実行しているため、Pythonインタープリタのオーバヘッドはなくなっている。処理時間が 41.6 ms -> 21.9 ms と約半分になっているが、これはほぼインタープリタのオーバヘッドが解消したためだ。

ここでわかるのは、単純にPythonと同じ処理をC言語で書き直すだけでは、numpyの62µsという圧倒的な速度には遠くおよばない、ということだ。 Pythonインタープリタである、というのも、Pythonの遅さの原因の一部でしかないのである。

インタープリタ言語というと、何もかもがコンパイル型言語より数百倍遅くなるようなイメージがあるかもしれないが、単純な処理ならそれほど遅くはならないことが多い。

Pythonの遅さの別の原因として、Pythonが静的な型定義を持たない、という点がある。

例えば、dot() では s += a[i] * b[i]という式を実行しているが、この中の X*Y のような乗算処理では、次のような処理が行われる。

  1. X が乗算をサポートしているかチェックする
  2. X の乗算関数を取得する
  3. Y が被乗数として適切なデータかチェックする
  4. X の値と Y の値を取り出し、乗算する
  5. 乗算の結果から新しい浮動小数点数オブジェクトを作成する

しかし、C や Javaのような、静的な型定義をもつプログラミング言語では、そもそも乗算を行えないような処理はコンパイルエラーとなるため、 上記の1. 〜 3. の処理の必要がなく、さまざまな最適化を行って処理を高速化できる。

Cythonでは、明示的にC言語のデータ型を指定して値を変換できる。まず、数値演算処理の部分にデータ型を宣言し、高速化してみよう。

%%cython
def dot_typed(a,b):
    cdef double s = 0.0
    for i from 0 <= i < len(a):
        s += <double>(a[i]) * <double>(b[i])
    return s

timeit dot_typed(a,b)
11.8 ms ± 189 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

この宣言により、次のような高速化が行われる。

  1. 上記の1. 〜 3. の処理は必要ない。
  2. 上記の 4. の演算処理は、Pythonの加算処理と乗算処理ではなく、CPUの機能で高速に演算されるようになる。
  3. 上記の 5. の浮動小数点数オブジェクトを生成せず、ハードウェアがサポートしている浮動小数点数が作成される。

だいぶ速くなったが、やはりまだ Numpy には及ばない。

dot_typed() では、演算中に <double>(a[i]) のようにして、Numpyの配列から要素を取得して、C言語double 型に変換している。実は、これもかなり複雑な処理なのだ。

  1. a が添字によるインデックスをサポートしているかチェックする
  2. a のインデックス関数を取得する
  3. 添字 i が添字として適切かチェックする
  4. i の整数値を取得する
  5. a から i 番目の値を取得する
  6. 取得した値で、numpy.float64 オブジェクトを作成する
  7. numpy.float64 オブジェクトを double 型に変換可能かチェックする
  8. numpy.float64 オブジェクトの変換関数を取得する
  9. double 型に変換する

この処理をスキップして、Numpy配列からデータを直接取得してみよう。

実は、Numpy配列には double 型のデータが格納されており、適切なデータ型を指定して直接参照してしまえば、変換は一切必要なくなってしまう。Numpy内部では、このような形式で要素を参照して効率的に処理を行えるようになっている。

Cythonには、Numpyなどのバッファを直接参照する、Typed Memoryview 型が備わっている。この機能で、単なる double 型データの入った配列としてアクセスできるようにしてみよう。

%%cython
def dot_view(double[:] a, double[:] b):
    cdef double s = 0.0
    for i from 0 <= i < len(a):
        s += a[i] * b[i]
    return s

timeit dot_view(a,b)
152 µs ± 6.32 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

おお。速くなった。データをPythonプロトコルを使って取り出す部分が最後のボトルネックだったわけで、この部分を最適化することで、Numpyの4割ぐらいのパフォーマンスまで迫ることができた。

のこりの速度差は、おそらくNumpy内部の内積処理の、高度な最適化によるものだと思う。もう一段階せまってみようかと思ったが、面倒なのでやめた。

つまり

  • Pythonを使った処理は遅くなるが、インタープリタだから、というのは実はそれほど大きな要因ではない。

  • Pythonの演算が遅い最大の要因は、Pythonが静的な型宣言を行わない言語で、型推論JITもなく、常に動的にオブジェクトの演算を行う、という点にある場合がほとんどだ。

  • Numpyでは、配列をすべて同じデータ型しか格納できない、 Homogeneous なコンテナとすることで、効率的に計算を行えるようにしている。

  • Numpy 「は?型推論?ぜんぶfloatにすればよくね?」

ついでに

Pythonが遅い原因として、GIL(Global Interpreter Lock) によってマルチコアをうまく使えないから、と言われることもある。

これもまあ、なくはない。

しかし、仮にGILがなくとも、Pythonの演算はせいぜいCPU数分しか速くならない。CPUが16個あってもたかだか16倍になるにすぎない。これではとてもNumpyには対抗できないのである。

おまけに

この根本的な原因は、今のJulia(v0.3.7)には破壊的な演算子が無いため、いつでも新しい配列を確保してしまう点にあります。

Pythonには、破壊的な代入演算子( += など) がある。

石本敦夫氏に聞く、Pythonの歴史とこれから〜Pythonエンジニア列伝 Vol.3 - PyQオフィシャルブログ でも話したが、Pythonにこの種の代入演算子が導入されたのは、実はNumpyで使用するためだった。

Python1.5までは、+= の導入には否定的な意見が多かった。これは、

X = [1,2,3]
X += [4,5]

のように、リストなどの更新可能なオブジェクトなら、リストオブジェクト X に新しく要素を追加すればよい。

しかし、同じようなスクリプトでも、

X = (1,2,3)
X += (4,5)

では、Xは更新不可能なタプルオブジェクトなので、要素を追加できない。この場合は、X に要素が追加されるのではなく、新しく (1,2,3,4,5) というタプルオブジェクトが、X に代入されることになる

このような判りにくさから、+= は導入されないという判断がくだされていた。

しかし、Numpyで大きな配列を効率的に演算するため、ということで必要性を認められ、導入されたのである。