yoheikikuta/paper-reading

[1901.02860] Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context [paper-reading]

yoheikikuta opened this issue · 8 comments

論文リンク

https://arxiv.org/abs/1901.02860

公開日(yyyy/mm/dd)

2019/01/09

概要

transformer などの fixed length のモデルだと longer dependence を取り込むことができないという問題がある。それを回避するために fixed segment length を前の segment の特徴量も取り込む形にしてスライドしながら適用させ、それを実現するために relative positional encoding を提案。前方に connection を貼ってスライドしていくので original の transformer の拡張というモデルになっている。
relative positional encoding では trainable だが位置によらない量で Q^T の positional embedding 部分を置き換え、trainable でない量で相対位置 i-j を embed する正弦波 embedding で K の positional embedding 部分を置き換えている。
実験で性能向上を示し、appendix では生成文章のサンプルが結構しっかりと載せられていてなかなか面白い。

BERT を色々といじったりしていて、入力の token 数が限定されているというのが気になることがある。
特に長い文章の特徴量を抽出したいというときに一部しか使えないので、その辺りを改善しているモデルを調べてみたいと思ってこの論文を読んでみる。

論文の押しているポイントとしても RNN の 80%、vanilla Transformers の 450% 長い依存性を学習できるようになったというところで、どうやってそれを実現しているかを見ていく。

入力が fixed length のモデルであっても、それをスライドさせて使っていけば長い文章に適用することは可能ではある。

例えば、1024 token からなる文章があったとき、入力 token が最大 512 token であるモデルであれば、 1-512 を入力にしてそれぞれの token の特徴量を得て、その後 513-1024 を入力にしてそれぞれの token の特徴量を得ればよい。
ただしこの方法ではスライドさせたときに位置情報が意味を成さなくなるし、言語モデルのようにある単語を基に次の単語を予測するという sequential なものも実現できない。
こういうことをするならほとんど単純な word2vec と同じだろう。

もう少し使えるものにするには token を一つずつずらして適用していくパターンである。
これがこの論文で言うところの vanilla transformer になっていて、イメージは以下の図である。
この場合、言語モデルのような振る舞いをさせることはできるが、やはり位置情報の意味はなくなる(例えば最初の適用で二番目だった token は、次の適用では一つスライドする結果一番目の token になるので、positional encoding は破綻している)。

もう一つの問題は、文章の意味的なつながりを考えずに扱える token 長で chunking していくという点だ。
例えば学習の時、上の図では Segment 1 と Segment 2 は別個の instance として学習していく。しかし本来はつながっている文章なので Segment 2 の特に最初の方の token は Segment 1 の後ろの方の token 関係性が強いはず。それを考慮できてないのでよろしくない、という問題である。
この問題を論文では context fragmentation と呼んでいる。予測の時にはスライドして適用していくのでこの問題は緩和はするが、学習と予測でやっていることが違うという本質的な問題は残る。

この辺りを解決する Transformer のモデルが欲しい、という話である。

主たるアイデアは以下の二つとなる。

  • self-attention を使ったモデルで recurrent に適用していく機構を使用
  • relative positional encoding の導入

これを聞くと別段特別なアイデアではないような感じがするが、実際なにか素晴らしい発明をしたという類の論文ではないと思う。

以降では、言語モデル $ P(x) = \prod P(x_t | x_1, ... , x_t-1) \ \text{where} \ x = (x_1, ..., x_T) $ をターゲットのタスクにしてどのようにモデリングをしていくか見ていく。

まずは segment level での recurrent 構造から。
以下の図を見れば大体の雰囲気は理解できる。

まずは学習の方から。
ポイントは緑色になっている connection でこれは segment 間をまたいでいる。これによって segment 間を関係付けることができるが、これはあくまで forward で使うのみで back prop では重みを更新しない。
つまり segment 毎にスライドしていくが、学習するのは一番新しい segment のみ、ということになっている。
また、次の segment での計算で前の segment での計算結果が使えるのでそれは cache しておくことで高速化が可能となる。

式で書けば以下のようになる。SG は Stop Gradient で gradient を流さないことを意味する。上で書いたようにこれは緑色の path で表現されている。

次いで予測の方。
これは学習の方が分かっていれば難しいことはなく、使える計算結果は cache してスライドしながら適用していけばよい。ここでの緑色の領域は path でつながっているのでこのコンテキストを利用して予測ができるという意味で、vanilla transformer では fixed length しか使えなかったが、このモデルでは segment をまたいで link があるので長いコンテキストを考慮できることが違いとなる。

長いコンテキストを使える、ということ以外に途中計算をキャッシュできるので、予測時に vanilla transformer と比べて 1800 倍速いと言っている。まあこれは vanilla transformer の計算が非効率なのでそれはそう、という感じ。

recurrent 構造に関してはこれで良いが、positional encoding の問題が残っている。
スライドして適用させていく以上、絶対的な位置を encode した特徴量を使うわけにはいかない。そこで相対的な位置情報を使うというアイデアを導入する。

positional encoding はある種のバイアスをモデルに与えるものであり、relative positional encoding では key(位置 i) と query(位置 j) の相対的な位置 i-j を依存するバイアスを与えるというアイデアである。

このアイデア自体はこの論文が初ということではないが、この論文ではこれまでとは少し違うが実験でもっとも結果の良いものを構築したとのこと。

まずは absolute positional encoding がどのように書けるかを復習。
token i, j の attention score の計算は以下のようになる。ここで、U が positional encoding で E が word embedding である。Query^T * (something) * Key の計算で something に行列をかませているというものになっている。

これを relative positional encoding へと変換しよう。
当然やり方は一意ではないが、以下のものを提案している。

まず R は relative な positional embedding で、これは学習なしの正弦波のものを用いる。original の transformer 論文のやつを i-j にしたもの。

また、positional embedding の transpose の項である (c) と (d) は学習パラメタ u, v に置き換えている。これはすなわち query が(この二つの項に限ってたが)位置によらず同一であることを意味しており、上のバイアスの議論を思い出せば query の位置によらずに self-attention のバイアスは同じにするというヒューリスティックを取り込んでいることになる。これの妥当性がどれくらいかはちょっと難しいところだが、既存研究の encoding と実験で比べてこちらの方が有効だった、というくらいに読める。

さらに、$ W_k $ を $ W_{k,E} $ と $ W_{k,R} $ の二つに置き換えていて、content-based なものと location-based な key vector を別個のものとして取り扱えるようにしている。

以上を踏まえてモデルの全体像を記述したものが以下。
この relative positional encoding とか cache の仕組みは実装上は少し面倒で気をつけなければならないが、理論的にはここまでの議論で尽きているのでそこまで難しいものではない。

あとは実験で比較という流れなのだが、性能評価には bits per character (bpc) と ppl (perplexity) が使われている。これは言語モデルの評価などでよく使われているものだが、ちょっと真面目に取り上げてみる。

まずは ppl から。
これは一言で言えば言語モデルの幾何平均である。$ \prod (1 / p(w_t | w_t-1, ... , w_1))^1/T $
言語モデルが良いモデルであれば、 $ p(w_t | w_t-1, ... , w_1) $ で正解の token を出す確率は 1 に近くなるので perplexity は 1 に近くなる。一方で例えばランダムの場合は 1/(可能な token の数) となってかなり大きな数となる。
書き方としてはよくこれを $ 2^{ - 1/T \sum log p} $ のように 2 の肩に上げたりする。経験分布 $ p = n / T $ (n は出現頻度) を用いて肩の部分は cross entropy とすることもできる。
こうしてみれば至極自然だが、適切な説明がないまま 2 の entropy 上ですとかいう謎の説明に出くわしたりするので注意が必要だ。

ついで bpc だが、こっちは調べても NLP の文脈ではいまいち明確に書いてあるものが少ない気がする。
情報理論で考えれば、これは言語モデルを使って information content を計算してその平均を取る、ということですよねきっと。なので当然これも小さい方が良い。

ということで実験結果をまとめてドン。良いですという結果。

ablation study の結果も一つくらい載せておく。

ちゃんと改良してそれを実験でも示しました、という結果になっている。モデリングがそこまで novel ではないのが着実に進歩している。

面白いのは appendix に生成した文章のサンプルが載せてある点である。
長い文章でもコンテキストをちゃんと保持して出力されている様子や、日付をちゃんと順番通りに正しく表示しながら文章を出力している様子が見て取れる。これはなかなか眺めていて面白い。

ということで中身は大体理解した。

実装 https://github.com/kimiyoung/transformer-xl は軽くしか眺めてないが、必要があれば真面目に読んでいこう。