【李宏毅老師2021系列】類神經網路訓練不起來怎麼辦(三):Batch

追客 Fredrick Hong
8 min readApr 30, 2022

這個系列是在觀看李宏毅老師 2021系列的筆記,希望能用更濃縮的方式將內容整理下來。

此篇講解的概念是在訓練深度學習網路中,可以調整的一種參數 (hyperparameter):batch size,也是我覺得實務上很有價值的一個觀念。

若對於實驗的過程沒有興趣,可以直接跳到總結看 large batch 跟 small batch 在各個面向比較的結論。

此篇筆記來自課程的前半段:【機器學習2021】類神經網路訓練不起來怎麼辦 (二): 批次 (batch) 與動量 (momentum) — YouTube

同樣因為 notation 的關係,如果想看比較美觀的版本也可以直接看 notion 的筆記版本:https://quixotic-revolve-92a.notion.site/Tips-for-training-Batch-and-Momentum-cfcbead0a7a040dab13a5802f3ed2e05

Review: Optimization with Batch

開頭老師先介紹 batch 實際上是怎麼樣計算 gradient,並沒有先說明為什麼要使用 batch 來訓練 network。

  • 我們實際上在算 gradient 的時候,並不是對所有 data 算出來的 loss 做微分,而是將整個 data 拆分成一組一組的 B (batch or mini batch),然後對每一組的 data 重複做以下動作。
* 圖中 θ = theta、g = gradient、L = loss* pick initial values theta⁰:選定特定一組參數 theta⁰計算。* 計算 g⁰:把 theta⁰ 這組參數放入 L 中得出 L¹,對 L¹ 做微分,就會得到 g⁰。* 計算 g¹ , g², …。
  • 1 epoch:把所有的 batch 都看過一遍
  • 實際上在 training 時,每一個 epoch 我們會做 shuffle (有不同方式),助教在作業做的shuffle,是讓每個 epoch 的每組 batch 的 data 都不一樣。
圖1 — 截圖自李宏毅老師課程

Why should we use Batch?

Small Batch v.s. Large Batch

Error Surface

  • 以 Error Surface 來觀察 training 時 L 的變化 (藍色 loss 小,紅色 loss 大)
  • 假設 training data 有20筆
  • 左側 Batch size = Full Batch,因此 model 要看完20筆 data 才會更新參數。
  • 右側 Batch size = 1,因此 model 每看完1筆 data 就會更新一次參數,在一個 epoch 中總共會更新20次,更新速度快但雜訊較多。

小結 (一):Large Batch 的技能cd長,但威力較大,Small Batch 的技能cd短,但雜訊較多。

圖2— 截圖自李宏毅老師課程

But Actually …

打臉剛剛的小結,平行運算時,Batch size 較大不一定需要比較長的時間去計算 gradient。

  • 圖3 用 MNIST 的手寫圖片的任務來測試
  • 當 Batch size 小於1000時,所需要的時間幾乎是一樣的,但超過1000時,還是會有它的限制。(此情況是使用 Tesla V100 GPU 的測試結果)
圖3 — 截圖自李宏毅老師課程
  • 圖4 左側講的是更新一次參數所需要花費的時間,右側是跑完一個 epoch 所需要的時間,可以看到 batch size 越大,跑一次 epoch 的時間越少。

小結 (二):在使用GPU的情況下,大的 batch size 速度快而且學習得夠穩定,完勝小的 batch size。打臉小結 (一)

圖4 — 截圖自李宏毅老師課程

Small Batch Advantage

課程看到看到這邊會覺得,只要硬體規格足夠,batch size 調得越大越好,但這時老師又接著介紹 small batch 有什麼優勢。

優勢 (一):Accuracy 表現較好 (參考圖5)

  • 如前面所提到的,使用小的 batch size 時訓練的 $L$ 看起來雜訊比較多,然而這些雜訊竟然可以提升 model 預測的準度。(MNIST、CIFAR-10 都是圖片辨識的任務)
  • 因為是在同一種 model 下測試不同 batch size 的結果,因此不會是 model bias 的問題,這是 optimization 的問題。
圖5 — 截圖自李宏毅老師課程

解釋優勢 (一):為什麼 small batch size 訓練出來的 model 準度比較高呢?

一種解釋的方式 (參考圖6)

  • Full Batch 的 training,會根據 loss 透過 gradient descent 優化,當我們遇到 critical point 時,gradient = 0,參數不再更新,若不透過 Hessian 或其他方式,optimization 就會卡住。
  • Small Batch 的 training,同樣也會卡在 critical point,但因為 small batch 有多組不同的 loss,因此訓練出來的準度反而比較高。
圖6 — 截圖自李宏毅老師課程

優勢 (二):Small Batch 在 testing data 表現更好 (參考圖7)

  • 這篇論文的實驗結果是,不管是 small batch 或 large batch 的 training accuracy 都表現得一樣好,但是在 small batch 在 testing accuracy 的表現優於 large batch。
  • 也就是說 large batch 會有 overfitting 的問題。
圖7 — 截圖自李宏毅老師課程

解釋優勢 (二):為什麼 small batch 比較不容易 overfitting?(參考圖8)

Flat Minima (good minima) and Sharp Minima (bad minima)

  • 一個 training 的 loss 可能會有很多 local minima,有分好或壞的,平坦的比較好,陡峭的比較不好。
  • 假設今天的 testing data 的 loss 是 training data 的 loss 往右偏一點,那麼平坦的 minima 在 testing data 的表現會跟 training data 差不多,而陡峭的表現則會差距很大。(實際上的testing data 可能因為 label 的分佈跟 training data 不同,導致兩者的 loss 也不相同)

一種解釋方式(並非每個人都認同,尚待研究)

  • Small Batch 比較容易卡在 Flat Minima,Large Batch 容易卡在 Sharp Minima
  • 直覺的想法是,small batch 參數更新的方向很多又很快,因此當 small batch 遇到小的峽谷 (sharp) ,可能會一下子就跳出去了,無法找到 sharp minima,所以通常會卡在 flat minima,而 large batch 的情況就是相反。
圖8 — 截圖自李宏毅老師課程

總結

  • 更新一次參數的速度(沒有平行運算):small batch 勝
  • 更新一次參數的速度(平行運算):相同
  • 訓練一次 epoch 所需要的時間:large batch 勝
  • Gradient:small batch (雜訊多)、large batch (穩定)
  • Optimization:small batch 勝
  • 綜合評價:small batch 勝

結語:最後綜合評價以 small batch 勝出,因此也回答了 why should we use batch 的問題,因為若不使用 batch 就等於是用 full batch (large batch) 來訓練 network。

圖9 — 截圖自李宏毅老師課程

發想:魚與熊掌可否兼得?

有些人想要實現 Large Batch 的運算速度 + gradient 的穩定度 + Small Batch 的 Optimization 的訓練方式。

以上就是這堂課程關於 Momentum 這個部分的筆記了,如果喜歡我的筆記,歡迎給個clap或留下留言!

--

--

追客 Fredrick Hong

畢業後就到在數位廣告業打滾,之前是廣告優化師,目前則是在數據團隊任職。