2024-01-25 機械学習勉強会
ZeRO: Memory Optimizations Toward Training Trillion Parameter Models
- Samyam Rajbhandari, Jeff Rasley, Olatunji Ruwase, Yuxiong He
- ‣
モチベーション
- 大規模な深層学習モデルは大きな精度向上をもたらすが、数十億から数兆のパラメータを学習させることは困難である
- データ並列やモデル並列のような既存のソリューションの課題
- データ並列 (Data Parallelism: DP)
- GPUあたりのメモリを削減できないため大規模なモデルではメモリ不足となる
- モデル並列 (Model Parallelism: MP)
- 通信のオーバーヘッドが大きい
- Zero Redundancy Optimizer(ZeRO)は学習可能なモデルサイズを増加しつつ高速な学習を実現する
ZeRO-DP / ZeRO-R
学習時のメモリ消費は大きく2つあり、これらの2つのメモリ使用量の改善のためにZeRO-DP, ZeRO-Rを開発
- Model State Memory
- optimizer states (e.g. momentum and variances in Adam)
- gradients
- parameters
- Residual State Memory
- activation
- temporary buffers
- unusable fragmented memory
Optimizing Model State Memory : ZeRO-DP
- Model Stateはしばしば学習中に最も多くのメモリを消費するが、DPやMPのような既存のアプローチは満足のいく解決策を提供しない
- DPはModel State全体をすべてのデータ並列処理に複製するため、冗長なメモリ消費となりメモリ効率が悪い
- MPは高いメモリ効率を得るためにModel Stateを分割するが、通信のオーバーヘッドを増加させ、スケーラビリティが低下する
- DPもMPも、学習プロセス全体にわたって必要なModel Stateをすべて保持するが、常にすべてが必要なわけではない
- 例えば、各層に対応するパラメータはその層のforwardとbackwardの間だけ必要である
- ZeRO-DPは、Model Stateを複製する代わりに分割することで、データ並列プロセス間のメモリ状態の冗長性を排除している
- ZeRO-DPには、optimizer state, gradients, parametersの分割に対応する3つの最適化ステージがある (Figure 1)
- Optimizer State Partitioning
- 並列数が のときoptimizer statesを 個の等しいパーティションに分ける
- 番目のデータ並列プロセスは 番目のパーティションに対応するoptimizer statesのみを更新する ( のoptimizer statesだけを保存すればよい)
- トレーニングステップの終了時にデータ並列プロセス全体でall-gatherする
- Add Gradient Partitioning
- optimizer statesと同様に、 番目のデータ並列プロセスは 番目のパーティションに対応するパラメータの勾配のみを保存・更新する
- backward時に別のデータ並列プロセスで必要となる勾配を転送し不要な勾配を削除する (reduce-scatter)
- Add Parameter Partitioning
- 同様に 番目のデータ並列プロセスは 番目のパーティションに対応するパラメータのみを保存・更新する
- forward, backward時にパーティション外のパラメータが必要な場合はデータ並列プロセスから受け取る
- これは一見すると大きな通信オーバヘッドが発生するように見えるが、このアプローチはベースラインDPシステムの通信量の1.5倍に増加するだけで に比例したメモリ削減が可能となる
- ZeRO-DPは、メモリの冗長性を排除し、クラスタの総メモリ容量をフルに利用できるようにする
- 3つのステージをすべて有効にすると、ZeROはわずか1024個のNVIDIA GPUで1兆パラメータモデルをトレーニングできる。
- ZeRO-DPでは、とを使用して追加の通信は発生せず、 を使用すると最大1.5倍の通信が発生する
Optimizing Residual State Memory : ZeRO-R
- ZeRO-DPは、Model Stateのメモリ効率を向上させたが、activation, temporary buffers, unusable fragmented memoryによって消費される残りのメモリがボトルネックとなりうる
- ZeRO-R は以下の3つの要素によって消費される残存メモリをそれぞれ最適化する
- activation
- activation checkpointingが役立つが大規模モデルでは不十分
- 例えば、1,000億個のパラメータを持つGPTのようなモデルの場合、activation checkpointingを使用しても、バッチサイズ32で約60GBのメモリを必要とする
- MPはModel stateを分割するが、しばしばactivationの複製が必要になる
- 例えば、線形層のパラメータを垂直に分割し、2つのGPUで並列に計算する場合、各GPUはその分割を計算するためにactivation全体が必要
- ZeRO は、activationを分割することで、この冗長性を排除し、activationが計算で使用される直前に、一度に 1 つのactivation layerのみを複製して実体化する
- その後各データ並列プロセスに分割され、backwardで再び必要となる際にはall-gatherする
- この最適化を と呼び、activation checkpointingによりパーティション化されたチェックポイントのみを保存する
- さらに非常に大きなモデルでメモリが非常に限られている場合、CPUにオフロードすることで追加の通信コストでactivation memoryのオーバーヘッドをほぼゼロにすることができる ()
- temporary buffers
- 中間結果を格納するために使用される一時的なバッファは、大規模モデルの場合、非自明な量のメモリを消費する
- 勾配の all-reduce や ノルムの計算などの演算はスループットを向上させるためにすべての勾配を単一の平坦化されたバッファに融合する傾向がある
- 例えば、1.5Bのパラメータを持つモデルの場合、平坦化された32 bitのバッファは6GBのメモリを必要とする
- fused bufferのメモリオーバーヘッ ドはモデルサイズに比例し、阻害要因になる可能性がある
- ZeROはこの問題に対処するため、モデルが大きくなりすぎた場合は、単純にパフォーマンス効率の高い一定サイズのfused bufferを使用することでメモリと計算効率のバランスをとる
- unusable fragmented memory
- 学習時におけるメモリの断片化は、activation checkpointingと勾配計算の結果発生する
- activation checkpointingでは、選択されたactivationのみがbackwardのために保存され、ほとんどのactivationは破棄されます
- これにより、短寿命メモリ(破棄されたactivation)と長寿命メモリ(チェックポイントされたactivation)が混在しメモリの断片化につながる
- 同様にbackwardの間、勾配は長寿命であり、勾配の計算に必要なその他のバッファは短寿命であり、メモリの断片化を引き起こす
- メモリの断片化は2つの問題を引き起こす
- 利用可能なメモリが十分にある場合でも、連続したメモリが不足することによるOOM
- 連続したメモリを探すのにメモリアロケータが多大な時間を費やすため、効率が悪くなる
- ZeROは、activation checkpointingと勾配のために連続したメモリチャンクを事前に割り当て、それらが生成されると事前に割り当てられたメモリにコピーすることでデフラグを行う