勾配降下法に慣性項を追加するMomentumをPythonで実装

  • このエントリーをはてなブックマークに追加

機械学習でよく使われる関数の最小値を探す最適化には様々な手法があります。ここでは勾配降下法を少し改善したMomentum(モメンタム)の式とPythonコードを紹介します。

こんにちは。wat(@watlablog)です。本記事では勾配降下法の派生手法であるMomentumを学習していきいます

Momentumの概要

勾配降下法の更新式

勾配降下法について不明点がある方は、「Pythonで1変数と2変数関数の勾配降下法を実装してみた」で1変数の式からできるだけわかりやすく説明しましたので、まずはそちらをご確認下さい。

ここでは先ほどの記事で紹介している標準的な勾配降下法の更新式を式(1)に示します。

$$\mathbf{x}_{i+1}=\mathbf{x}_{i}-\eta \nabla f (1)$$

この更新式は学習率\(\eta\)が勾配にかかっているだけでとてもシンプルですが、学習の進捗に応じて単調な変化しかしません。

そして一度勾配がフラットに近くなると急激に変化量が少なくなるので、計算時間がかかる割にはちょっとした局所最適解にもはまりやすいというデメリットがあります。

Momentumは更新式を一部変更することでこの問題を改善します。

前回の更新量を利用して慣性項を追加する

Momentumの更新式は式(2)です。

$$\mathbf{x}_{i+1}=\mathbf{x}_{i}-\eta \nabla f + \alpha \Delta \mathbf{w} (2)$$

この式は先ほどの基礎式の右辺最後に係数\(\alpha\)と前回の更新量である\(\Delta \mathbf{w}\)の積を加えている所が違いになります。

この\(\alpha \Delta \mathbf{w}\)は変化量に対して係数がかかっているので、運動方程式で言う所の慣性項と同等の意味を持ちます。

慣性項とは、ある質点に外力が働かなくなっても動き続けようという性質のことで、ニュートンの運動法則の1つです。

言葉で説明するよりも動画を見た方が理解が早いと思いますので、以下のGIF動画をご覧下さい。

GDとMomentumの比較動画

この動画はシンプルな勾配降下法と今回コーディングするMomentumの比較です。降下する関数は式(3)です。

$$z=\frac{1}{4}x^{2}+y^{2} (3)$$

この関数は\(x\)軸方向も緩やかなカーブを描いているので、表示されている平面の中心が最小値です。

赤い点がMomentumとなりますが、GD(勾配降下法)よりも速く最小値に到達していることがわかります。

また、\(y=0\)に到達した瞬間はオーバーランしている慣性項の特徴も確認されました。

今回はこのMomentumをPythonでコーディングします。

Momentumのメリットとデメリット

Momentumは単純なGDが持つ急激な変化(緩やかな勾配ですぐに変化量が小さくなること)をできるだけしないようにハイパーパラメータを調整することができます。

滑らかな変化量とすることで計算の効率を向上させたり、ちょっとしたくぼみを乗り越えたり、より最適な解へ到達しやすくなるメリットがあります。

しかし、\(\eta\)の他に\(\alpha\)と2つのハイパーパラメータに増えたことで、その調整が容易ではなくなるというデメリットも持ちます。

現代では日々最適化問題を効率良く解くためのテクニックが研究されていますが、ここでは基礎からの理解としてまずはこのMomentumをPythonで書いていきたいと思います。

Momentumで最小値を探すPythonコード

全コード

以下のコードはGDとMomentumを比較するためのPythonコードです。それぞれの手法で更新式以外はほとんど同じであり、書き方も「Pythonで1変数と2変数関数の勾配降下法を実装してみた」で紹介した時とほぼ変わっていません。

今回、反復計算中の前回のベクトル更新量としてupdateという変数を用いましたが、このupdateにどこまで含めるのが正解なのかの解釈ができていません。もしかしたら正式のものとはちょっと違う可能性があります。

僕は以下の書籍P164の式を読んで今回のコードを書きました。時間があればもっとゆっくり読んだり、他の文献を見たりしてみたいと思います。

Pythonで学ぶニューラルネットワークとバックプロパゲーション

実行結果

以下が実行結果です。同じ学習率でも慣性項がある方が若干早く最適解に到達していますね。

Momentumにおける最小値探査の実行結果

おまけ:軌跡を画像にする

ちなみに、先ほどの動画はPythonのmatplotlibを静止画にしてフリーソフトでGIFにしただけです。今回のコードで静止画を保存する方法は以下のコードに載せておきます。

上記コードはalphaの値を高めに設定しているため、下図のような挙動になります。やりすぎると慎重派のGDの方が優秀であるといった結果にもなりますね。

やりすぎたMomentum

まとめ

今回は基礎的な勾配降下法を少し改良したMomentumの式を学びました。

Momentumは運動量というその名の通り、まるで慣性が働いているかのような最適化手法であることがわかりました。

Pythonコードもこれまでの記述方法と全く同じなので特に難しさは感じませんでしたが、最適化手法というのは非常に奥が深いとも感じました。

様々な最適化手法がありますが、万能な方法は今の所なく、今回のMomentumもハイパーパラメータの調整を必要とします。

ちょっとした式の違いでこれだけ挙動が変わるのは大変面白いですね!Twitterでも関連情報をつぶやいているので、wat(@watlablog)のフォローお待ちしています!

  • このエントリーをはてなブックマークに追加

SNSでもご購読できます。

コメント

コメントを残す

*