JAX

はじめに

 深層学習フレームワークの主流は、TensorFlowPyTorchの2択に収束しつつある。前者はGoogleが、後者はFacebookが開発したオープンソースである。ところが最近、JAXと呼ばれる新興ライブラリがじわじわと勢力を広げてきた。開発元はTensorFlowと同じGoogleである。Googleは、たとえ社内のパイを食い合うことになろうと、次々と新しいものを作る社風だ。今回は、このJAXを取り上げる(先日、PyTorchもJAXに触発されてJAXライクなライブラリfunctorchをリリースした)。

JAXとは

 JAXの特徴は以下の通り。

  • NumPyと互換なAPIを持つ。
  • 自動微分の機能を持つ。
  • コードを動的にコンパイル(JITコンパイル)でき、CPU/GPU/TPUごとに最適化できる。

  • JAXは深層学習に特化したライブラリではなく、その基盤となるライブラリである。すでにJAXを利用した深層学習や深層強化学習のフレームワークが作られている。

    問題設定

     具体的な問題を与え、老舗のフレームワークPyTorchとの速度比較を行う。対象とする問題は簡単な単回帰である。

     観測データD=\{(x_1,y_1)\cdots,(x_N,y_N)\}\;(x_n,y_n\in\mathbb{R})が与えられたとき、xyの間の関係を次の形で求めたい。

    (1)    \begin{align*} y=ax+b \end{align*}

    損失関数として次式を考える。

    (2)    \begin{align*} L(a,b)=\sum_{n=1}^N\left(y_n-(ax_n+b)\left)^2 \end{align*}

    この関数の値を最小にするようなパラメータa,bを、勾配降下法により求める。各パラメータの更新式は次の通りである。

    (3)    \begin{align*} a&\leftarrow a - \mu\frac{\partial L(a,b)}{\partial a}\\ b&\leftarrow b - \mu\frac{\partial L(a,b)}{\partial b} \end{align*}

    ここで、\muは学習率と呼ばれる正の微小量である。損失関数の変動がなくなるまで上の更新を繰り返す。

     ここからは具体的な実装例をコードを抜粋して示す。全ソースはここにある。

    観測データの作成

     以下のコードで観測データを作成した。

    jnpjax.numpyの別名である。PyTorch版、JAX版のどちらの関数もNumPyでデータを作成し、最後にそれぞれのライブラリで使えるインスタンスに変換している(6行目)。今回は、a=5,b=2とし、そこにノイズを追加してある(下図参照)。

    学習コード:PyTorchの場合

    一般的なPyTorchの学習手順である。

  • 8行目:\frac{\partial L(a,b)}{\partial a}\frac{\partial L(a,b)}{\partial b}のそれぞれを0に初期化する。
  • 9行目:式(1)を計算する。これが予測値y_predである。
  • 10行目:式(2)を計算する。
  • 11行目:誤差逆伝播が実行され\frac{\partial L(a,b)}{\partial a}\frac{\partial L(a,b)}{\partial b}が計算される。
  • 12行目:式(3)でパラメータが更新される。lrは学習率である。
  • 学習コード:JAXの場合

    JAX版に対しては以下2つの場合の速度比較を行った。

  • 学習ループを包含するjax.lax.fori_loopを使う方法(9行目)
  • 学習ループをそのまま書く方法(13行目から14行目)

  • 先に後者の説明を行う。ここで呼び出している関数train_の中身は以下の通り。

  • 1行目:デコレーターjax.jitを付けるとJITコンパイルされる。
  • 4行目:\frac{\partial L(a,b)}{\partial a}\frac{\partial L(a,b)}{\partial b}を計算する。
  • 6行目:式(3)を計算し、パラメータを更新する。

  • 4行目のgrad_lossの正体は以下の通り。

    関数lossは式(2)を実装したものである。この関数をjax.gradに渡し、paramslossの第1引数)で微分させている(6行目)。paramsの中身は以下の通り。

    つまり、a,bL(a,b)が微分される。

     次に、学習ループを包含する関数jax.lax.fori_loopを使う方法を解説する。以下の9行目である(再掲)。

    関数trainの中身は以下の通り。

    関数内関数body_funを定義している。これは先に見たtrain_の中身と同じである。関数body_funは、10行目のjax.lax.fori_loopに渡される。実はこの関数は次のコードと同じ意味を持つ。

    詳細なからくりは調べていないが、ループをjax.lax.fori_loopに置き換えると大変顕著に高速化されることを次に示す。

    速度比較

     実行環境はAWS EC2インスタンス p3.2xlargeである。そのスペックは以下の通り。

  • Tesla V100 1個
  • GPUメモリ 16GB
  • vCPU 8個
  • メモリ 61GB

  • 学習時間の一覧は以下の通り。

    PyTorchとJAXのいずれに対してもGPUを有効にしてある。PyTorchでもJITを使うことができるらしいが、今回は割愛した(コードを全面的に書き直す必要があるようだ)。上の結果を見ると、最速は、JITなしのJAX(fori_loop)版であることが分かる。これは少し意外である。素直に考えれば、JITありのJAX(fori_loop)版になりそうだからだ。おそらく、今回の計算処理は軽いので、動的コンパイルのためのオーバヘッドの方が大きかったのだろう。JITコンパイルの有無に関わらずjax.lax.fori_loopは積極的に使うべきである。

     最後に、予測された直線とa,bの値(小数点以下4位で四捨五入した)を示す(下図参照)。正解値はa=5,b=2である。

    参考文献

  • Quora
  • HELLO CYBERNETICSさんのブログ
  • まとめ

     今回は、最近流行のJAXを取り上げ、PyTorchとの速度比較を行った。JAXは確かに速いことが分かった。JAXのJITコンパイルに使われるコンパイラーはGoogleが開発したものでありXLAと呼ばれる。気になる方は調べてほしい。また、JITコンパイルの有無に関わらず、ループを特定の関数で置き換えると大変高速になることも示した。
     ところで、JAXは何の略なのか。「What does JAX stand for?」でググったが良く分からなかった。jax.lax.fori_loopのforiの由来も分からない。

    Kumada Seiya

    Kumada Seiya

    仕事であろうとなかろうと勉強し続ける、その結果”中身”を知ったエンジニアになれる

    最近の記事

    • 関連記事
    • おすすめ記事
    • 特集記事

    アーカイブ

    カテゴリー

    PAGE TOP