Chainerによる転移学習とファインチューニングについて(VGG16、ResNet、GoogLeNet)



画像系の深層学習では、学習済みモデルの重みを利用する「転移学習」や「ファインチューニング」と呼ばれる手法がよく利用されます。

  • 転移学習: 学習済みのモデルから特徴量を抽出すること
  • ファインチューニング: 学習済みモデルの重みを使って再学習させること

どちらも基本的には、ILSVRCなどの画像認識コンペで優秀な成績を収めたモデルのネットワークアーキテクチャを深層学習のライブラリで構築し、公開されている学習済みの重みファイルを読み込ませて利用するという流れで実装します。

Chainerでは、以下の画像認識モデルが、すでに内部で実装されています。

  • VGG16
  • ResNet50, ResNet101, ResNet152
  • GoogLeNet

また、これらのモデルに学習済みの重みファイルを読み込ませるための便利な関数が一通り揃っていますので、それらの使い方についてまとめます。

今回扱うChainerのバージョンは 3.0.0 とします。

転移学習で画像特徴量を抽出する

Chainerから呼び出せる学習済みモデルはいずれも、ImageNetの1000クラス分類のタスクを学習したモデルになります。

今回のテスト用に入力する画像として、「パグ」の画像データを1枚用意しました。

1000クラスのラベルに「パグ」が含まれており、クラス番号は 254 になります。

Chainerから呼び出した学習済みモデルが、ちゃんとこの画像を入力して 254 を予測するか確認してみます。

VGG16

VGG16の学習済みモデルを呼び出すには、 chainer.links.VGGLayers 関数を使います。

Chainerの実装では chainer.links.VGGLayers 関数の引数に pretrained_model が指定でき、何も指定しない場合にはデフォルトで学習済みモデルの重みを自動でダウンロードする仕組みになっています。

上記で、 重みを反映させたモデルインスタンス vgg16 を宣言したことになります。

実際に画像を入力して推論させる時は、モデルインスタンスの __call__ 関数に入力データをセットして実行することになります。

VGG16に限らず、ネットワークによって、入力する画像をそのネットワーク独自の前処理を行う場合がほとんどです。

例えば、VGG16の場合は入力画像に対し、

  • (batch_size, chennels, height, width) = (batch_size, 3, 224, 224) に合わせる
  • カラーのチャンネルの順番は (BGR) にする
  • 各画素値から平均値 (103.939, 116.779, 123.68) を引く

といった処理を施した後、ネットワークで推論させます。

Chainerでは、Pillow で読み込んだ画像データに対して、上記の処理を自動でやってくれる chainer.links.model.vision.vgg.prepare 関数があります。

推論した結果には、クラス分類の最終レイヤーである全結合層のベクトル値が出力されますので、これを元に予測クラスを確認できます。

問題なく、 254 の「パグ」を推論してくれました。

さて、転移学習では、この学習済みモデルによる画像の特徴量を抽出することを指します。

一般的には、最終レイヤーの全結合層の一つ手前のレイヤーの出力ベクトル値を特徴量として得ることが多いですが、用途やレイヤーごとの結果を見て選択したりと様々です。

先程のモデルインスタンスの __call__ の実行には引数 layers が指定でき、どのレイヤーのベクトル値を得るかを指定することができます。

指定にはレイヤーの名前を文字列の配列で指定できます。

モデルが出力できるレイヤーの名前は以下で確認ができます。

VGG16の場合は最終レイヤーの全結合層が fc8 なので、手前の fc7fc6 などを指定して、以下のように出力ベクトル値を得ることができます。

また、Chainerでは モデルインスタンスの extract 関数を使うと、自動でクラス分類の全結合層の一つ手前のベクトル値を得ることができます。

extract 関数では、内部で chainer.links.model.vision.vgg.prepare による前処理も実装されているため、画像データをそのまま渡して結果を得ることも可能です。

以上が、Chainerを使って、学習済みのVGG16による特徴量を得る方法となります。

ResNet152

続いて、ResNet(ここでは例としてResNet152)になりますが、他のモデルもVGG16の時と同様のインタフェースで、学習済みモデルの読み込み、実行ができるようになっています。

ResNet152の場合は、chainer.links.ResNet152Layers モデルを呼び出します。

ResNet50の場合は chainer.links.ResNet50Layers、ResNet101の場合は chainer.links.ResNet101Layers でそれぞれモデルを宣言できます。

ResNetに対する入力画像の前処理も同様に、chainer.links.model.vision.resnet.prepare 関数で実行できます。

モデルの推論で指定できるレイヤーは以下になります。

よって、 res5pool5 などのベクトル値を以下のように取得できます。

ResNetも同様に extract 関数で、クラス分類の全結合層の一つ手前のベクトル値が得られます。

GoogLeNet

GoogLeNetの場合も同様です。

モデルは chainer.links.GoogLeNet 関数で呼び出せます。

chainer.links.model.vision.googlenet.prepare 関数で前処理を施して、推論が可能です。

指定できるレイヤーは以下になります。

GoogLeNetの場合はネットワークがややこしく、評価時のロスを計算するレイヤーが複数あるのに対し、予測時に計算するレイヤーは loss3_fc のみになります。

どのレイヤーを指定するかは、やはり用途や結果を見てといった話になりますが、ラベルの予測レイヤーの手前のレイヤーだと、以下などになります。

同様に extract 関数で、クラス分類の手前のベクトル値が得られます。

学習済みの重みを使ってファインチューニングする

学習済みモデルは、畳み込みなどのレイヤーで画像の特徴をうまく捉えられています。

これを利用して、最後の方のクラス分類をするレイヤーを任意のレイヤーに切り替えて再学習させると、一からモデルを学習させるよりも高い精度で学習できる場合があります。

そのようにして、学習済みモデルのレイヤーに任意のレイヤーを付け加えて再学習させる方法をファインチューニングといいます。

これまでに記した通り、Chainerでは学習済みモデルのネットワークを便利に使える関数が揃っていますので、ファインチューニングも比較的楽に実装できます。

VGG16

学習済みのVGG16をベースに、ファインチューニングさせるモデルを構築する例を紹介します。

転移学習の時に記したように、 layers 引数でモデルから出力させるレイヤーを選択できることを利用し、学習済みのレイヤーから得られたベクトルから、タスクに対応する任意のレイヤーを追加してモデルを実装することができます。

問題にもよりますが、VGG16の場合は、pool5層、fc6層、fc7層などのいずれかの出力を取って、新しくfc層を1〜3層追加することが多いです。

例えば、画像を入力して fc7層までの値をとり、最後に out_size 数のクラスに分類するモデルを実装した場合は下記のようになります。

例えば、10クラスに置き換えた場合には、以下のようにして、10次元のベクトルを得られるように、レイヤーが追加されていることが確認できます。

また、ファインチューニングでは、学習でオプティマイザをセットする際に、学習済みのレイヤーの重みはあまり学習しないように、ハイパーパラメータを設定する場合が多いです。

これも問題によって指定の方法は様々ですが、個人的によくやっているのは、下記のように、学習済みのレイヤーのみ学習率を抑えることが多いです。

学習済みのレイヤーは学習させない(重みを完全に固定する)場合には、disable_update関数が使えます。

あとは普通に学習部分を実装すれば、ファインチューニングをさせることができます。

ResNet152

ResNetも同様に実装が可能です。

res5層、pool5層などのいずれかの出力を取って、新しくfc層を1〜3層追加します。

モデルインタンスの宣言後は、オプティマイザなどの設定は同じように実装できます。

GoogLeNet

GoogLeNetに関しては、上記でも述べましたが、出力までの推論がややこしく構成されています。

評価時は、loss1loss2loss3を使って計算させますが、予測はloss3(loss関数を通す前のベクトル)のみ利用します。

したがって、それぞれのlossに向かうfc層を変更して、同様に複数のlossを使って評価を行うモデルに変更する必要があると考えるのが自然です。

しかし、GoogLeNetの layers 引数に、loss1_fc1loss2_fc1が選択できないようになっています。(指定しても何も返ってこない)

したがって、現状では、同様のネットワークを自前で作成する必要があるみたいです。



 

コメントを残す

メールアドレスが公開されることはありません。 * が付いている欄は必須項目です