アニメーションでDeepSpeed (ZeRO1)の仕組みを完全に理解する

概要

  • DeepSpeed stage1(Optimizer Stateの効率化)についてコードとアニメーションでわかりやすく説明してくれている記事
  • DeepSpeed ZeRO1の流れ
      1. optimizerの初期化。optimizer stateに分割したパラメータだけを保持する。
      1. forward()を実行してlossを計算する
      1. backward()を実行して勾配を計算する
      1. 2で計算した勾配を全体で平均化する(reduce all)
      1. step()でOptimizer Stateを計算し、モデルの重みを更新する
      1. 部分ごとに計算したモデルの重みが全体に行き渡るようにbroadcastする