PyTorch C++の導入〜その3〜

はじめに

深層学習フレームワークPyTorchのインターフェースはPythonであるが、バックエンドはC++である。現在、C++のインターフェース(C++11)も整備されつつある。前々回前回からPyTorch C++(LibTorch 1.2)による実装例の解説を行っている。今回は第3回目、最終回である。

コードの説明

本シリーズの目次は以下の通り。今回はこの内5以降を全て説明する。全ソースはここにある。

  1. コード全体の概説
  2. 引数の抽出
  3. デバイスの選択
  4. モデルの定義
  5. データセットの読み込み
  6. 最適化器の準備
  7. 訓練済みモデルのロード
  8. モデルの訓練
  9. モデルの保存

データセットの読み込み

関数main内の該当部分は以下の通り。

3行目の関数load_datasetの返り値はクラスCustomDatasetのインスタンスtrain_datasetである。このインスタンを引数にして、13行目でtrain_loaderを作成する。関数torch::data::make_data_loaderは、テンプレート関数であり、テンプレート引数として、RandomSamplerSequentialSamplerを取り得る。前者は、データをランダムに返すので、訓練時に使うことができる。後者はデータをそのまま返すので、評価時に使うことができる(17行目)。
クラスCustomDatasetは、ユーザ定義のデータセットを使用するため定義したクラスである(custom_dataset.h)。その中身は以下の通り。

データセットを扱うクラスはテンプレートクラスtorch::data::Datasetをpublic継承し、2つの仮想関数getsizeを実装しなければならない。前者は指定したインデックスにおける画像とラベルを対にしたものを返す関数、後者はデータ数を返す関数である。前者の値を返す部分(16行目)でcloneを使っていることに注意する。クラスtorch::Tensorのコピーコンストラクタは浅いコピーを行う実装になっている。深いコピーを返す必要がある場合は明示的にcloneを呼ぶ必要がある。

ところで、上の親クラスのテンプレート引数が子クラス自身になっていることに気付いただろうか。これはC++の有名なidiomであり、Curiously Recurring Template Pattern(CRTP)と言う名前が付いている。静的にポリモーフィズムを実現する仕組みであり、例えば以下のように使うことができる。

これらを以下のコードで実行すると

以下の出力を得る。

子クラスにimplementationが実装されていればそれが実行され、実装されていなければ親クラスの実装が実行される。つまり、動的ポリモーフィズムと同じ振る舞いを静的に実現できるのである。動的ポリモーフィズムと異なり、仮想関数テーブルなどのオーバヘッドが存在しないので、高速に動作する。

クラスCustomDatasetの実装部分(custom_dataset.cpp)では、CIFAR10からダウンロードしたバイナリファイルを読み込む作業を行う。煩雑なので掲載は割愛する。ソースを見て欲しい。

話が長くなったのでもう一度関数main内のデータ読み込み部分を示す。

関数load_datasetの返り値に対し、関数mapを2度呼び出している(4,5行目)。最初のmapは、データの値を255で割り、値を[0,1]に収める処理である。2番目のmapは、バッチ単位のデータセットの持ち方を指定している。上のようにStackを呼び出すと、画像をバッチ数だけ集めたコンテナAとラベルをバッチ数だけ集めたコンテナBをペアにした構造が作られる。一方、2番目のmapを呼ばない場合、画像とラベルのペアをバッチ数だけ集めた構造が作られる。前者の方が訓練時のコードが書きやすい。3行目のtrain_datasetが出来上がったら、torch::data::make_data_loaderを使って、データローダtrain_loaderを作る(13行目)。データローダでデータを取り出しながら訓練を行うことになる。torch::data::make_data_loaderの第2引数にはバッチサイズとスレッド数が渡されている。

最適化器の準備

関数mainの該当部分は以下の通り。

第1引数でモデルの全パラメータを、第2引数に学習率LEARNING_RATEを渡している。

訓練済みモデルのロード

関数mainの該当部分は以下の通り。

訓練済みモデルを読み込んで訓練の続きを行う場合は、モデルだけでなく、最適化器のインスタンスoptimizerも読み込む必要があることに注意する。

モデルの訓練

関数mainの該当部分は以下の通り。

関数trainの中身は以下の通り(main.cpp)。

  1. 11行目:訓練モードに設定する。他に評価モードが存在する(後述)。
  2. 13行目:データローダからバッチ単位でデータセットを取り出す。
  3. 15,16行目:画像とラベルをGPUデバイス側へ転送する。
  4. 17行目:最適化器の初期化。
  5. 19行目:ネットワークに入力を与え順伝播させる。
  6. 20行目:出力(10次元ベクトル)の最大要素のインデックスを求める。これが予測したラベルである。
  7. 24行目:正解ラベルと予測ラベルの一致度を計算する。一致すれば1、不一致なら0をバッチ数分だけ足す。あとでバッチ数で割り、平均値が計算される。
  8. 26行目:torch::nll_lossはNegative Log Likelihood Loss関数である。これが、最小にすべき目的関数である。
  9. 29,30行目:偏微分を行い、誤差逆伝播を行う。
  10. 32行目から40行目:ログ出力

関数testの中身は以下の通り。

  1. 10行目:評価モードに設定する。
  2. 13行目:データローダからバッチ単位でデータセットを取り出す。
  3. 15,16行目:画像とラベルをGPUデバイス側へ転送する。
  4. 17行目:ネットワークに入力を与え順伝播させる。
  5. 20行目から24行目:損失を計算する。
  6. 25行目:予測ラベルを求める。
  7. 26行目:予測値と正解値が一致した数を累積する。
  8. 29行目:損失値のバッチサイズでの平均値を求める。

モデルの保存

関数main内の該当部分は以下の通り。

後で学習を再開するにはmodelだけでなく、optimizerも保存しなければならない。

まとめ

今回を含めて3回に渡ってPyTorch C++による実装例を見てきた。ホームページを見ると、Pythonインタフェースから予測できるようなC++インタフェースを整備していると書かれており、確かにその印象を受けた。「これどうやって書くんだ?」と思った時は大抵、Pythonインタフェースからの推測で当たることが多い。
クラスのインスタンスがstd::shared_ptrで所有者管理されること、データセットを扱うときCRTPと呼ばれる静的ポリモーフィズムが使われることなど、C++としても興味深い機能が採用されていることを見た。実務に使うか否かは別にして、C++で深層学習を行う1つのツールとして紹介した。ただし、コード記述と実行の間にコンパイル作業が入るので、スクリプト言語に慣れた身にとっては面倒臭いことは確かである。

Kumada Seiya

Kumada Seiya

仕事であろうとなかろうと勉強し続ける、その結果”中身”を知ったエンジニアになれる

最近の記事

  • 関連記事
  • おすすめ記事
  • 特集記事

アーカイブ

カテゴリー

PAGE TOP