読者です 読者をやめる 読者になる 読者になる

もちもちしている

おらんなの気まぐれブログ

Dropoutの実装と重みの正則化

この記事はMachine Learning Advent Calendar 2013 3日目の記事です.

はじめに

ニューラルネットワークの汎化性能を向上させるDropoutは, Deep Learningを実装する上で必須の技術だと思います. 本日はDropoutとその実装方法について説明させていただきます.

Dropoutとは

ニューラルネットは複雑なモデルであるため過学習に陥りやすいです. これを回避するためにはL2ノルムで値の増加を防いだり, L1ノルムでスパースにしたりするのが一般的です.

しかし正則化でもニューラルネットのような複雑なモデルに適切に制約を加えるのは困難です.

そこでDropoutの考え方です. Dropoutは各訓練データに対して中間素子をランダムに50%(または任意の確率で)無視しながら学習を進めます. 推定時は素子の出力を半分にします.

f:id:olanleed:20131130221427p:plain

なぜこれだけで汎化性能が向上するのかというと, "アンサンブル効果"で説明がつきます.

アンサンブル効果

アンサンブル学習

アンサンブル学習とは個々に学習した識別器を複数用意し, 出力の平均を取るなど, それらをまとめあげて一つの識別器とする方法です. 一人で考えるよりも複数人で議論したほうが良い結果を得られるのと同じでしょうか.

バギング

ブートストラップによってランダムに生成した異なる訓練データから大量のモデルを学習し, 平均化することで推定を安定させる方法です. これを取り入れた学習器にRandam Forestが挙げられます.

しかしニューラルネットの学習には時間がかかるので, 実際にバギングを行うのは現実的ではありません. ですがDropoutを行うことにより, 訓練データ毎にモデルが少し異なるためバギングと同様の効果が得られます.

推定時は全素子を使いますが, 素子の出力を半分にすることでモデルの平均を取ることに相当します.

Dropoutの実装

いろいろなアプローチがあると思いますが, ここではドロップアウトしたい隠れ層の素子をマスクする方法を紹介します.

Forward Propagate

ある隠れ層

 \begin{eqnarray*}h=f(a)=f(Wx+b)\end{eqnarray*}

があるとします. この式にドロップアウトマスクをかけてあげるだけです.

h_m = h \circ m

mは通過させたい出力を1, 阻止したい出力を0と表現したベクトルです.

Back Propagate

Back Propagate時も素子をなかったことにしないといけません.


\delta=({}^t\!W_p\delta_p)\circ f^{\prime}(a) \circ m

もし活性化関数にシグモイド関数を使っていた場合は以下のように式を変形することができます.


 f^{\prime}(a)=h_m\circ(1-h_m)

これでDropoutの実装は終わりです!

重みの正則化

よりよい重みwを速く求める, そして更に汎化性能を向上させるには重みの正則化をしたほうがよいとされています.

Momentum

モーメント法と呼ばれる勾配法の高速化手法です. 学習率が低いとなかなか収束しないため, 重みを調整してあげるのがmomentumの役割です.
勾配法によるBackPropagationの重み更新式は以下のように与えられています.


\Delta w = -\epsilon \frac{\partial E}{\partial w}

\epsilon : 学習率

これにmomentum係数を導入して式を以下のように修正します.


\Delta w = -\epsilon \frac{\partial E}{\partial w} + \mu \Delta w

\mu : momentum係数

Weight decay

汎化性能向上のための手法の一つで, Weight decayを導入すると重みの爆発が抑制され, 過学習を防いでくれます.


\Delta w = -\epsilon \frac{\partial E}{\partial w} -\epsilon \lambda \Delta w

\lambda : Weight decay係数

他にも中間素子に入力する重みのL2ノルムが閾値を越えたら, 閾値と同じ値となるように線形スケーリングを行う方法もあります[1].

Dropoutの欠点

Dropoutの欠点として, たくさん更新しないと学習が進みません(計算時間が増える).
この欠点を克服した"Fast Dropout"がありますが, これはまた別の機会に紹介したいと思います.

おわりに

BackPropagationにDropoutを適用したソースコードを公開します. 間違いがありましたらご連絡ください.

https://github.com/olanleed/BackPropagation