【李宏毅老師2021系列】類神經網路訓練不起來怎麼辦(一):Local Minima and Saddle Point
這個系列是在觀看李宏毅老師2021系列的筆記,希望能用更濃縮的方式將內容整理下來。
這篇如標題所示,主要先講解概念定義,再透過公式推導和實際例子來解釋 Local Minima 與 Saddle Point 這兩個概念。
此篇筆記來自此課程:【機器學習2021】類神經網路訓練不起來怎麼辦 (一): 局部最小值 (local minima) 與鞍點 (saddle point) — YouTube
在開始之前,想先跟大家說聲抱歉,因為我的筆記是紀錄在 notion,可以直接使用 Latex 的 notation(好用又美觀)。
但因為我還沒找到可以在 medium 直接使用 notation 的方式,因此直接複製過來,很多數學符號和變數可能會不太好懂。
如果想看比較美觀的版本也可以直接看 notion 的筆記版本:https://quixotic-revolve-92a.notion.site/Local-Minima-and-Saddle-Point-29367f7bd3f3414186477729106a0647
When gradient is small…
Optimization Fails because…
* 參數對loss的微分等於0 (gradient=0),因此參數不再更新,我們稱之為critical point。
* critical point 有兩種情況。
* local minima:局部loss的最小值。
* saddle point:非local minima也不是local maxima,但gradient為0。
* 處於local minima代表loss已經位於相對低點,但處於saddle point則代表我們的loss仍有很多優化空間。
How to find out we stuck at local minima or saddle point?
Tayler Series Approximation
* 直接講結論就是,我們無法知道 L(θ) 確切的樣子,但我們可以透過 L(θ’) 大約算出來。 (公式如下左圖)
* L(θ’) :給 L(θ) 訂某一組參數,而 L(θ’) 附近的參數是有辦法被寫出來的,就是Tayler Series Approximation。
* 當g (gradient)等於0時,就只剩下公式右半邊的數值,我們就可以透過這個 L(θ’) 來畫出 error surface,藉此了解我們是處於 local minima 還是 saddle point,這種方式稱之為 Hessian。
Hessian
- 判斷
* [(θ-θ’)^T * H(θ-θ’)]/2,縮寫成v^T*Hv
* 對於任何 v 而言,只要 v^T*Hv > 0,那麼 L(\theta) > L(\theta’),代表 L(\theta’)是附近最低點 (local minma)。
* 若 L(\theta) < L(\theta’),那就是local maxima。
* 有時大於0,有時小於0,那就是處於saddle point。
* 我們不可能帶入所有有可能的 $v$ ,透過線性代數理論有更簡單的方式去判斷。
* local minima:v^THv,Hessian是positive definite的矩陣,這種矩陣的特性是所有eigen跟values都是正的,那麼 v^THv > 0 就會成立。
* local maxima:反之亦然。
saddle point:eigen 跟 values 有正有負。
- 結論
* 我們會算出一個 H 的矩陣叫做 hessian,裡面會有 eigen 跟 values,只需要去判斷它們是正或負,就可以知道目前的 function 處於哪一種 critical point。
Example
圖片解釋
* 假設我們有個史上最爛的function, $y = w1w2x$,和史上最爛的training data,只有一筆data (input=1, label=1)。
* 當我們爆量蒐集各種 $w1 w2$,所畫出來的 error surface 就如下圖。
* saddle point: 非 local minima,也非 local maxima。
* 因此最中間的點就是 saddle point,因為往左上或右下 loss 會增加,往左下或右上 loss 會減少。
* local minima: 山溝裡面有好幾個點代表局部最小值。
公式推導
* L = (\hat y — w1w2)² = (1 — w1w2)²
* 接著把 loss function 的 gradient 求出來(loss function 分別對 w1, w2 做微分),gradient 用 g 表示。
* Critical point 就是當 g = 0,而我們可以發現 w1=0, w2=0 時, g =0 成立。
* 而要了解這個 critical point 是 saddle point 或 local minima,則需要看 hessian (H)。
* H 是一個矩陣,而裡面的值就是收集了 L 的二次微分,接著把 w1=0, w2=0 帶入, H 算出來就會是 [[0, -2], [-2, 0]],這四個值有正有負,因此代表這個 critical point 就是 saddle point。
Don’t afraid of saddle point
- 如果今天 function 卡在的 critical point 是 saddle point,其實 H 不但顯示我們卡在 saddle point,更指引了參數更新的方向。
回顧在 critical point
1. $L(\theta)\approx L(\theta’)+ \frac{1}{2}(\theta-\theta’)^TH(\theta-\theta’)$,我們將其中的 $\frac{1}{2}(\theta-\theta’)^TH(\theta-\theta’)$ 表示為 $v^THv$。2. 當 v^THv 有正有負就代表我們處在 saddle point,而且 H 會告訴我們參數更新的方向。
參數如何更新 — 1
- 假設一
* u 是 H 的 eigen vector。
* \lambda 是 u 的 eigen value。
- 如果我們將 v^THv 中的 v 替換成 u,就可以推導出
* u^THu = u^T(\lambda u) = \lambda ||u||²
* 因此可以說 \frac{1}{2}(\theta-\theta’)^TH(\theta-\theta’) 就是 \lambda ||u||²
* 當 \lambda < 0 時, u^THu 就會 < 0,而 L(\theta) < L(\theta’) 就會成立。
- 將 v 替換成 u
* 回顧上面 \frac{1}{2}(\theta-\theta’)^TH(\theta-\theta’) = v^THv
* 我們又將 v 替換成 u,就代表 \theta — \theta’ = u,等於 \theta = \theta’ + u。
* 把本來參數在 \theta’ 的位置加上 u,沿著 u 的方向更新,就會讓 L 變小。
參數如何更新 — 2
- 同樣是剛剛的例子,我們處在中間黑點的 critical point。
* 這時候 gradient = 0,因此需要透過 Hessian 來判斷接下來參數如何更新。
* 此時的 $w1=0, w2=0, \lambda_1=2, \lambda_2=-2$。( $\lambda$ 就是 eigen value)
* 此時的 $H$ 是一個 $[[0, -2],[-2, 0]]$ 的矩陣。( $H$ = Hessian )
- 上面公式的推導,說明了當 $\lambda < 0$ 時,參數只要沿著 $u$ 的方向更新, $L$ 就會變小
* 因此當 $\lambda_2 = -2$, $\lambda_2$ 有一個 eigenvector $u = [[1],[1]]$
* 這時候只需要跟著 $u$ 這個矩陣的方向更新參數,就會減少 $L$ 了。
小結
從這個角度來看,training 時卡在 saddle point 並不可怕。 然而實際上,並不會把 Hessian 算出來,因為「二次微分」跟「要找出 eigen value」的計算量太大了,後面還會講解其他計算量較小的解法。
Saddle Point v.s. Local Minima
魔法師狄奧倫娜:如何從封閉的石棺中偷出聖杯
- 從3維的空間石棺是封閉的,但在更高維度的空間中並不是封閉的。
當我們在視覺化 Training 的 error surface 時
- 從2維看,以為是 local minima,但從3維看,其實是 saddle point。
- 有沒有可能我們從3維看,以為是 local minima,從無法視覺化的更高維度看,其實是 saddle point。(右半邊)
- 因此當參數量很多時,會不會根本沒有 local minima。
實驗支持假設
- 假設
* 當 function 參數量很多時,沒有 local minima,其實實務上都是卡在 saddle point。
- 實驗
* 訓練很多 functions,圖中每一點都是一個 function。
* 縱軸代表 training loss,橫軸代表 minimum ratio。
* minimum ratio = Number of Positive Eigen values / Number of Eigen values。
* 當 Eigen values 都是正值時,critical point 才會是 local minima,因此 ratio =1 才是 local minima。
- 結果
* minimum ratio 頂多超過0.5一些,越往右邊越像 local minima,但仍然不是真正的 local minima,因此假設成立。
下集預告
接下來會講解當我們卡在 local minima、saddle point、或者接近 saddle point 的 plateau 時,實務上的解法。
懶人包
- Critical Point:當我們在訓練的過程中,當 gradient = 0 時,參數不再更新的狀況。
* 有兩種狀況 gradient = 0
* Local Minima:局部 loss 的最小值。
* Saddle Point:非 local minima 也不是 local maxima。
* 處於local minima代表loss已經位於相對低點,但處於saddle point則代表我們的loss仍有很多優化空間。
- 怎麼判斷我們是在 Local Minima 還是 Saddle Point
* 我們可以透過 Hessian 判斷。
* Hessian 是一個矩陣,其中包含 eigen values。
- 處在 Saddle Point 怎麼辦
* 除了可以透過 Hessian 判斷我們處於哪種 Critical Point 之外,也能透過 Hessian 判斷參數更新的方向。
* 但實務上因為 Hessian 的計算量太大,因此會用其他方法來逃離 Saddle Point。
- Local Minima 跟 Saddle Point 在不同維度空間中的誤解
* 當我們處在更高維度觀察 Local Minima時,看起來會像是 Saddle Point。
以上就是這堂課程的筆記了,如果喜歡我的筆記,歡迎給個clap或留下留言!