PyTorch C++の導入〜その1〜

はじめに

深層学習フレームワークPyTorchのインターフェースはPythonであるが、バックエンドはC++である。現在、C++のインターフェース(C++11)も整備されつつある。今回から3回に分けてPyTorch C++(LibTorch 1.2)による実装例の解説を行う。

なぜC++を使うのか?

主な理由は以下になるだろう。

  • Pythonを使えない環境で深層学習を行う必要があるから。
  • 実時間性がクリティカルになる環境で深層学習を行う必要があるから。
  • マルチスレッド環境で深層学習を行う必要があるから。Pythonは、GIL(Global Interpreter Lock)のため2つ以上のスレッドを同時に走らせることができない。
  • C++で実装された既存システムとシームレスに接続し深層学習を行う必要があるから。

個人的な理由は以下の通りである。

  • 私のホームグランドがC++であるから。
  • Pythonばかりの日常に少し飽きてきたから。

インストール方法

ここの通りにすれば良い。実装したプログラムのコンパイルにはcmakeを使うので、あらかじめこれをインストールしておく必要がある。今回のプログラムの動作確認は、EC2インスタントの1つであるAWS深層学習AMI上で行った。

やりたいこと

データセットCIFAR10を用いた10分類問題である。具体的には、この本の2章にあるKeras実装例を書き換え、以下のようなコマンドラインアプリを作る。GPU上でも動作するアプリである。

引数の意味は以下の通り。

  • –batch_size:バッチサイズを指定する。
  • –epochs:エポック数を指定する。
  • –resume:訓練済みモデルを用いて学習を途中から再開したいときはtrueを指定する。
  • –model_path:訓練後のモデルを保存するパスを指定する。
  • –opt_path:訓練後の最適化器を保存するパスを指定する。
  • –verbose:ログ出力する場合はtrueを指定する。
  • –trained_model_path:訓練済みモデルへのパスを指定する。
  • –trained_opt_path:訓練済み最適化器へのパスを指定する。

最後の2つのパスは、–resumeがtrueのときに使う。

コード説明

次の目次に沿って説明する。全ソースはここにある。

  1. コード全体の概説
  2. 引数の抽出
  3. デバイスの選択
  4. モデルの定義
  5. データセットの読み込み
  6. 最適化器の準備
  7. 訓練済みモデルのロード
  8. モデルの訓練
  9. モデルの保存

今回は、最初の「コード全体の概説」を行う。それ以外の項目については次回以降に説明する。

コード全体の概説

関数mainの中身は以下の通り(main.cpp)。

  • 5行目から20行目まででコマンドラインの引数を抽出する。
  • 24行目の関数manual_seedで乱数の種を固定する。
  • 25行目から36行目まででGPUが使えるかを判断する。
  • 40行目でネットワークを定義する。
  • 41行目のtoメソッドでモデルをGPUデバイス側へ転送する。

  • 46行目から62行目まででデータセットを読み込む。
  • 66行目から69行目までで最適化器を定義する。
  • 73行目から80行目までで、訓練済みモデルが指定されていればこれを読み込む。
  • 83行目から87行目まででモデルを訓練(関数train)・評価(関数test)する。
  • 93行目からの2行でモデルを保存する。

バッチサイズ32、エポック数10のときの訓練時間は26秒程度となった。一方、Kerasのコードでの訓練時間は70秒程度である。どちらの実行でもログ出力をオフにして計測を行った。PyTorch C++版の方が2倍以上速いことが分かる。Kerasのコードでは学習時に関数fitを呼び出している。この関数はマルチプロセスを実現する引数workersを持たない。一方、今回示したC++の実装では、データローダ作成時にworkersに2を渡し、マルチスレッドによるデータ供給を行わせた。この違いが速度に現れているのかもしれない。テストデータに対する精度は、Keras版が0.5032、PyTorch C++版が0.506なので差はほとんどない。乱数で10分類したときの精度(0.1)よりは5倍ほど良い。

まとめ

今回はコードの全体像と計算時間、精度について説明し、PyTorch C++版の方がKeras版より高速に処理されることを見た。バックエンドではどちらもGPUを使うので、純粋にCPU側の差であろう。ただし、上で見たように、現在のコードでKeras版とC++版の速度を比較することはフェアでない。実は同じ本の3章にあるCNNを含むネットワークもPyTorch C++で実装したのだが、Keras版と速度的には互角であった。バックエンドにGPUを使う場合、訓練速度の優劣を評価するのはナンセンスなのかもしれない。

次回以降で、上に掲げた目次に沿ってコードの中身を説明する。

Kumada Seiya

Kumada Seiya

仕事であろうとなかろうと勉強し続ける、その結果”中身”を知ったエンジニアになれる

最近の記事

  • 関連記事
  • おすすめ記事
  • 特集記事

アーカイブ

カテゴリー

PAGE TOP