PyTorch C++の導入〜その2〜

はじめに

深層学習フレームワークPyTorchのインターフェースはPythonであるが、バックエンドはC++である。現在、C++のインターフェース(C++11)も整備されつつある。前回からPyTorch C++(LibTorch 1.2)による実装例の解説を行っている。今回は第2回目である。

コードの説明

本シリーズの目次は以下の通り。今回はこの内2,3,4を説明する。全ソースはここにある。

  1. コード全体の概説
  2. 引数の抽出
  3. デバイスの選択
  4. モデルの定義
  5. データセットの読み込み
  6. 最適化器の準備
  7. 訓練済みモデルのロード
  8. モデルの訓練
  9. モデルの保存

引数の抽出

関数main内の該当部分は以下の通り。

関数parse_argumentsの中身は以下の通り。

名前空間poboost::program_optionsの言い換えである。コマンドライン引数を扱う処理をboostを用いて実装した。

デバイスの選択

関数main内の該当部分は以下の通り。GPUを使える環境であればGPUを使うよう設定する。

3行目のコードはプログラム全体で使う乱数の種を固定する処理なので、GPU云々とは関係ない。

モデルの定義

関数main内の該当コードは以下の通り。

今回使うネットワークは3層の全結合層から構成される単純なものである。3行目で、クラスArchitectureのコンストラクタを呼び出している。これの第1引数は画像を1次元ベクトルに変換したときの次元数、第2引数は分類数(10)である(今回考える問題はCIFAR10を用いた画像の10分類問題である)。4行目でGPUデバイスのメモリにモデルを転送する。3行目でmodelを値として定義しているのに、関数toにアクセスする際になぜ演算子->を用いているのか疑問に思うかもしれない。この種明かしは後述する。5行目の関数は学習で決定されるパラメータ数を表示する。その定義は以下の通り。

model->named_parameters()によりモデル内の全層のパラメータに階層的にアクセスすることができる。この出力は以下の通り。

CIFAR10の画像サイズは32×32である。RGBの3チャンネルあるから1次元ベクトルに変換した時の次元数は32x32x3=3072になる。これを第1層で200次元のベクトルに変換する。第2層では150次元に、第3層で10次元のベクトルに変換する。学習で決まる総パラメータ数は614400+200+30000+150+1500+10=646260である。

ネットワークを定義するクラスは以下の通り(architecture.h)。

ネットワークのアーキテクチャはtorch::nn::Moduleをpublic継承して実装しなければならない。そして、入力値xを処理する関数forwardを実装する(8行目)。関数名に縛りはないがforwardにしておくのが無難である。今回考えるネットワークは3つの全結合層から構成される。従って、これらをインスタンス変数として定義する(11行目から13行目)。全結合層はクラスtorch::nn::Linearとして提供されている。

次に、16行目の記述TORCH_MODULE(Architecture)を説明する。これは、PyTorch C++のソースコードtorch/csrc/api/include/torch/nn/pimpl.hで定義された次のマクロである。

この記述には別のマクロTORCH_MODULE_IMPLが書かれている。その定義もpimpl.hに記載されている。

つまり、TORCH_MODULE(Architecture);は以下のように展開される。

ここで、torch::nn::ModuleHolderは以下で定義されるテンプレートクラスである。

テンプレート引数ContainedArchitectureImplが渡されるので、7行目でstd::shared_ptr<ArchitectureImpl>が定義されることになる。以上をまとめると TORCH_MODULE(Architecture)と言う記述により、クラスArchitecureImplのインスタンスがスマートポインタにより所有者管理されるクラスArchitectureが自動生成されることになる。クラスArchitectureには、そのインスタンスをポインタライクに扱うための関数が定義されているので、先に触れた構文mode->to(device)が書かれることになる。このTORCH_MODULEを用いた仕組みは、PyTorch C++の主要なクラスに踏襲されており、全結合層のクラスtorch::nn::Linearも例外ではない。

ここで1つ注意すべき点がある。ユーザが定義したクラス(ここではArchitectureImpl)にデフォルトコンストラクタが存在しない場合、以下のコードはコンパイルエラーになる。

なぜなら、上のコードは内部でstd::make_shared<ArchitectureImpl>()を実行しており、これはArchitectureImplにデフォルトコンストラクタがない場合エラーを返すためである。空のインスタンスを作りたい場合は以下のようにすれば良い。

クラスModuleHolderには上の構文を受け付けるコンストラクタが用意されている。

クラスArchitectureImplの実装部分(architecture.cpp)は以下の通り。

ここで注意すべき点は、訓練すべき層をregister_moduleで明示的に「登録」することである(6行目から8行目)。これにより、内部で使うパラメータやバッファなどに階層的にアクセスすることが可能となり、全層の重みに対して誤差逆伝播が可能となる。関数forwardの処理内容は以下の通りである。

  1. flatten:入力値xのサイズ[batch_size, channels, rows, cols]を[batch_size, channels * rows * cols]に変換する。
  2. 第1層の出力に活性化関数ReLUを適用する。
  3. 第2層の出力に活性化関数ReLUを適用する。
  4. 第3層の出力に活性化関数Softmaxを適用し対数をとる。

まとめ

今回は、目次の項目2,3,4を説明した。PyTorch C++の主要なクラスはstd::shared_ptrで管理されること、学習される重みを持つインスタンスはregister_moduleを使って明示的に登録する必要があることを説明した。次回も引き続きPyTorch C++の説明を行う。

Kumada Seiya

Kumada Seiya

仕事であろうとなかろうと勉強し続ける、その結果”中身”を知ったエンジニアになれる

最近の記事

  • 関連記事
  • おすすめ記事
  • 特集記事

アーカイブ

カテゴリー

PAGE TOP