2024-02-15 機械学習勉強会
LightRNN: Memory and Computation-Efficient Recurrent Neural Networks
Xiang Li, Tao Qin, Jian Yang, Tie-Yan Liu
NIPS2016
- 昔軽く読んで賢いなぁという印象を持っていた論文。最近Xで言及している投稿を見て久々に思い出してピックアップしました。
- 大規模言語モデルが盛り上がり、embeddiingがとても一般的な概念になった今、知っておいたもいい手法だなと考えています。RNNがベースとなっていますが、RNNに限らず応用可能な話であると理解しています。
サマリー
- RNNよりもモデルサイズと計算コストを大幅に削減しつつ、性能を維持したLightRNNを提案。
- ひとつの単語の分散表現を獲得するために、複数の単語で共有された2つのembeddingを利用する、2-Component (2C) shared embedding という仕組みがキーアイデア
- 実験において、モデルサイズを最大1/100に抑えた上で学習時間を1/2に高速化しながら、性能(perplexity)を同等以上に保った。
背景
- 【前提】2016年に発表された論文なので、マシンリソースの前提などは現在とは大きく異なります。
- 系列データを扱えるRNN(Recurrent Neural Network)が、機械翻訳や言語モデリングなど自然言語処理の多くのタスクで利用されるようになった。
- 一方で、ボキャブラリサイズが大きい場合、入力と出力の埋め込み行列が非常に大きくなりパラメタ数が非常に大きくなる。
- たとえばボキャブラリが10M(10^7)の場合
- ClueWebを想定している。最新版だと ClueWeb22で100億のWebページを参照。
- https://paperswithcode.com/dataset/clueweb22
- 埋め込み表現を1024次元とし、浮動小数点で表現するなら
- 入力行列は (10^7個の単語) x (1024次元) x (4バイト) = 40GB
- 出力行列も同様なので、入出力の行列だけで80GB以上になる。
- 当時のマシンリソース状況を鑑みると、学習および推論に際して十分なGPUメモリを持った環境を用意することはなかなか難しいことが想像できる。
- メモリの問題もあるし、計算コストの問題もある。この大きさでソフトマックス計算するとつらい。
- 当時の環境での(?)シングルGPUでClueWebを用いてRNNを学習させると数十年計算とのこと。
- 既存手法はいろいろあるが、モデルサイズと計算コストの大幅な削減を同時には達成していない
- 計算コスト
- 基本、ソフトマックスの計算を抑える手法
- hierarchical softmax
- importance samplingやblack outなどのサンプリングベースの近似計算とか。
- モデルサイズ
- 出力層の行列を小さくするための differentiated softmaxやrecurrent projection
- 入力層の行列を小さくするための Character-level convolutional filters
- モデルサイズと計算コストを抑えながら性能も維持する、という問題に取り組むのが本研究の位置づけ
提案手法
2-Component Shared Embedding
- この論文の一番のミソ。
- ひとつの単語の分散表現を獲得するために、複数の単語で共有された2つのembeddingを利用する。
- 従来(下図左)は、語彙数が16の場合embeddingも16個(x_1, x_2, …, x_16)必要。”February”という語彙はx_2というembeddingで表現する。
- 提案手法では、語彙空間全体を2次元のテーブルで表現し、各行と各列に独立したembeddingを割り当てる。(下図中央)
- ひとつの語彙のembeddingを、その語彙の行と列の2つのembeddingで表現する。
- “February” は x^r_1 と x^c_2 の2つのembeddingで表現される。
- 各行と列のembeddingは複数の語彙のembeddingを表現するのに利用される。つまり共有されている。
- x^r_1 は “February”だけではなく、”January”にも利用されている。
- この仕組みにより、語彙数がVの場合、√V 個のembeddingですべての語彙のembeddingが表現可能。
- この語彙の割り当て方については後述。
RNN Model with 2-Component Shared Embedding
- 2-Component Shared Embeddingで語彙を表現するために基本的なRNNの構造(下図右)に手を入れる。
- この構造をLightRNNと名付ける(下図左)。
- シンプルに入力語彙W_(t-1)に対して該当するふたつのembedding、行ベクトル x^r_(t-1)と列ベクトル x^c_(t-1)のふたつが入力となる。それぞれに対して隠れ状態 h^r_(t-1)、h^c_(t-1)が存在する。
- 入力の次元数が語彙数のVから√V になっており、行と列のふたつの入力が必要となっていることから、入力の次元数は2√V となる。出力も同様。これにより大きくパタメタ数が小さくなる。
- 先ほど同様に、embeddingの次元数1024、語彙サイズ 10M、32ビット浮動小数点でモデルの大きさを概算すると、、
- 2 x 1024 x √(10^7) x 4バイト = 50MB
- あとはシンプルなRNNと同様
- それぞれの隠れ状態を該当する入力と隠れ状態を利用して計算
- t番目の語彙 w_t についての確率 P(w_t) は該当する行確率(P_r(w_t))と列確率(P_c(w_t))の積で算出する。
- ここの確率の計算も、次元がVから√V に削減していることで大きく効率化されている。
Bootstrap for Word Allocation
- 2次元で表現した単語の割り当てテーブル(word allocation table)の作成方法
- LightRNNで学習した語彙のembeddingに基づき、語彙の配置を繰り返し変更する以下の手順のブートストラップ手順
- 初期状態ではテーブルにランダムに語彙を割り当てる
- 与えられた割り当てテーブルに基づいてLightRNNでembeddingの学習を行う。
- 学習時間やPPLなどの停止基準を満たした場合は終了
- 満たさない場合は3に進む
- 前ステップで学習したembeddingを固定した上で、損失関数が最小になるように割り当てテーブル内の語彙を並び替える。その後2に戻る。
- 損失関数を最小化するアルゴリズム(完全には理解できてない)
- ある語彙の位置を別の位置に移動した際の損失をすべてのパターンにおいて計算しておいて、以下の最適化問題として解く。
- 損失のパターンとしては学習の過程で計算しているものも多い。
- standard minimum weight perfect matching problem(最小重み完全マッチング問題)に帰着する。
実験設定
- シンプルに言語モデリングタスクで性能評価を実施
- 評価指標
- perplexity(PPL)
- データセット
- ACLW(6つの言語ごとのデータセット)とBillionW(10億語を含むデータセット)を利用
- BillionWは特に前処理されていないので、出現頻度の極めて低い語彙のカットなど必要な前処理を適用。
- マシン
- すべての実験において GPU K20 with 5GB memory を1台
- https://mim-corp.co.jp/product/nvidia-tesla-k20
- 比較手法
- ACLWデータセットに対しては(当時)sotaのLSTMベースの手法で比較
- HSM: 語彙予測に階層的ソフトマックス
- C-HSM: 階層的ソフトマックス + 入力に文字レベルの畳み込みフィルタ
- BillionWデータセットに対してはRNNベースで以下を適用させるものを主に比較
- B-RNN: BlackOutを利用
- KN: Kenser-Ney 5-gramを利用
実験結果と考察
- embeddingの次元を変えた実験を実施
- 200, 600, 1000
- 提案手法ではモデル全体のサイズを大きく削減できるので、embeddingの表現力を高くしやすい。
- embeddingの次元数を大きくするとPPLは小さく(良く)なる。モデルの大きさはさほど大きくない。
- ACLWデータセットにおけるモデルサイズとPPLの比較
- LightRNNはモデルサイズがかなり小さいが、PPLについてもベースラインを上回る。
- 語彙数の増加に対してもロバスト。たとえばFrenchはEnglishの2倍の語彙数であり、ベースラインは線形にモデルサイズが大きくなっているが提案手法はほぼ同等の大きさ。
- BillionWデータセットにおけるモデルサイズとPPLの比較
- 「KN + 」はアンサンブル
- LightRNNはモデルサイズを小さく保ちつつPPLも良い
- モデルサイズは1/100に!
- table-refinementsの回数とPPLの遷移
- 3,4回くらいでsotaを超える。
- 2つのデータセットにおいてベースライン(C-HSM)と同じ性能を達成するのに必要な学習時間
- 提案手法では半分ほどの実行時間でベースラインを達成している。
- rellocationが占める時間は非常に小さい。
- 単語テーブルを見てみると、行ごとに同じような意味の単語が集まっていることが見られた。
- 832は場所、889は時間的概念、887はURL
- 考察:意味的に近いものが同じ行に集まり embeddingを共有することで、ベースラインよりも高い性能が出ているのではないか。
感想
- 「2-Component Shared Embedding」シンプルに賢くて好き。単語割り当てテーブルの配置最適化も現実的に解ける形になっていて良き。
- 今回は言語モデルとしての評価のみであったが、下流タスクにおいてどのような性能になるのかは気になり。
- k=2ではなくもっと多くのembeddingに分けるとどうなるのか気になり。研究されてそうだけど調べてられていない。