PyTorchでニューラルネットワーク、CNNを実装してみた

今回は、PyTorchでニューラルネットワーク、畳み込みニューラルネットワークの実装について記します。

以前にChainerの実装をまとめたときのものと同じタスクを実装してみて、比較しやすいようにしてみました。

Chainerでニューラルネットワーク、RNN、CNNを実装してみた
RNNの実装の勉強もしました。今回は整理と備忘録も込めて、Chainerでニューラルネットワーク、リカレントニューラルネットワーク、畳み込みニューラルネットワークの実装について記します。ニューラルネットワーク(Neural network; NN)順伝播型のニューラルネット...

PyTorch

PyTorch: http://pytorch.org/

PyTorchはFacebookが開発する深層学習に特化したライブラリです。

Pythonから実行できます。

インストールは公式ページを見るとわかりやすいと思いますが、OSや環境などをポチポチ選択していけばインストールコマンドが参照できますので、それを実行するだけです。

2017年始め頃に登場したばかりですが、瞬く間にTensorFlorやKerasに続く人気の深層学習ライブラリとなっているようです。(自分の周りでは使っている人がいないですが…)

また、PyTorchは、Preferred NetworkのChainerから影響を受けているようで、Chainerと同様、計算時に動的にグラフを構築する(Define-by-Run)ライブラリです。

書き方もTensorFlowやKerasなどよりもChainer寄り、というかむしろそっくりな書き方をするのが特徴です。

PyTorchによるニューラルネットワークの実装

順伝播型ニューラルネットワークを実装してみます。

Chainerのときと同様に、簡単な問題、排他的論理和を解くモデルを実装しました。

GitHub: https://github.com/Gin04gh/samples_py/blob/master/NeuralNetwork_PyTorch.ipynb

かなりChainerに似ていますね。

データはDataLoaderというクラスで読み込むようです。

また、numpy配列のままではなく、torchの配列にする必要があるみたいですが、ソースからわかるように、numpy配列をfrom_numpyメソッドでtorchの配列に変換できます。

PyTorchによる再帰的ニューラルネットワーク(RNN)の実装

RNNも最初は実装してみようかと思ったのですが、今回は省略します。

どうやらチュートリアルを見てみても、LSTMも今のところ自前で作らなければならなそうでした。

しばらく様子見していれば実装されそうな気もしますし、とりあえず今回はスルーすることにしました。

追記

と思ったけど、nn.LSTMCellとかありますね。

ちょっと勉強をしてみます。

PyTorchによる畳み込みニューラルネットワーク(CNN)の実装

続いて畳み込みニューラルネットワークになります。

こちらも画像分類ではチュートリアルとしておなじみのMNISTの画像分類問題を実装してみました。

GitHub: https://github.com/Gin04gh/samples_py/blob/master/ConvolutionalNeuralNetwork_PyTorch.ipynb

以上、PyTorchでニューラルネットワークと畳み込みニューラルネットワークを実装してみました。

Chianerと一緒とまではいきませんが、やはり基本的にChainerを使っている身からすると、学習コストは低い方なのかなと思います。

あと、個人的に使っていて思ったのは、Chainerよりも学習速度が早く感じました。

Chainerでニューラルネットワーク、RNN、CNNを実装してみた
RNNの実装の勉強もしました。今回は整理と備忘録も込めて、Chainerでニューラルネットワーク、リカレントニューラルネットワーク、畳み込みニューラルネットワークの実装について記します。ニューラルネットワーク(Neural network; NN)順伝播型のニューラルネット...

上記のときと実行時間を比べてみても、やっぱりだいぶ早いのではないかと思います。

基本的に深層学習はGPUで動作させるものなので、CPUだけの結果だと何とも言えませんけど…。

 

コメントを残す

メールアドレスが公開されることはありません。 * が付いている欄は必須項目です