Mambaという深層学習モデルについて

こんにちは,バックエンドの藤岡です.

私は現在大学4年(あと数週間で大学院生ですが...)で,深層学習分野の研究をしています.PlayGroundには大変知識豊かな方々が沢山おり,このTechブログ等でその知見を共有されていますが,私も日々論文を読んでいて大変面白かったものの1つである "Mamba: Linear-Time Sequence Modeling with Selective State Spaces"という論文を通じて深層学習の知見を共有できたらと思います.

Mambaは深層学習界隈ではある程度有名なため,既にインターネット上には大変わかりやすい日本語解説ブログが数多く出回っています.したがって,今回はあくまで紹介ということで,できるだけ噛み砕いた"雰囲気"をお伝えできればと思います.

(詳細や正確な情報が知りたい方は,是非とも元論文を読んでみることをオススメします.)

背景

現在,Chat GPTなど深層学習を用いたアプリケーションの多くはTransformerという深層学習モデルをもとに構築されています.しかしこのTransformerというモデルはあまり計算効率が良くなく,今までこの計算効率を改善しようと様々なモデルが研究されていますが,現状の結果は芳しく有りませんでした.

筆者らは,これら改善を試みたモデルの弱点は"内容を十分に把握して推論できていないため"ではないかと考え,これらモデルの1つであるSSMベースのモデルに以下の2点の改善を加えました.

  1. モデルの入力から必要な情報を保持したり不必要な情報を忘却したりと,より柔軟に入力情報を扱えるようにした
  2. GPUでの計算アルゴリズムを,より効率的なものへの再構築した

これら改善を取り入れたモデルを,筆者らはMambaと呼んでいます.

改善点1:入力情報の扱いの改善

先ほど紹介したTransformerや筆者らが改善を加えたSSMといった,時系列情報を扱う深層学習モデルは,入力された情報をどれだけモデル内で小さな情報に圧縮して扱えるかで「計算効率」が左右されます.そして,この「計算効率」は「モデルの性能」とトレードオフの関係となっています.ここでの「計算効率」は計算時間,必要なGPU量に関係するもので,「モデルの性能」は文字通りどれだけ性能の良いモデルかというものです.

つまりざっくりいうと,モデルの入力情報を一切圧縮せずに処理したら性能は良くなるが計算効率は最悪,モデルの入力を超圧縮すると性能はイマイチだが計算効率は抜群,といった感じです.ちなみに前者はTransformer,後者はRNNが近いですね.

故に,効率的かつ高性能なモデルにするには,入力情報から必要なデータのみを選択してモデル内で保持しておけばよいとわかります.

そこで筆者らは,従来モデルのS4のパラメータを入力に依存するように修正することで,この入力情報から必要なデータのみを選択して保持しておく仕組みを実現しました.(具体的にどのようにして改良したかは以下に簡単に説明しますが,少し専門用語を使うので読み飛ばしてもらっても構いません.)

改善方法

(図は元論文より)

筆者らはS4のパラメータのいくつかを入力シーケンス依存とすることで,データの選択メカニズムを構築しました.主な変更点は以下の2点です.

  1. Δ,B,Cがパラメータから入力についての関数になった
  2. テンソルの形状が総じてシーケンス長Lをもつように変化した

特筆して,パラメータがシーケンス長次元Lを持つようになったため,モデルが時不変から時変に変わりました(以下補足).そのため,S4の畳み込みの特性(計算効率を良くするためのテクニック)が失われますが,その点を以下の改善点2でカバーします.

補足:SSMsのLTI(線形時間普遍性)について

SSMのダイナミクス(Δ,A,B,C)は時間(ここだと系列番号に該当)で不変な定数であり,LTIとなっています.この特性からSSMsの再帰や畳み込みの性質が担保されています.(詳細はS4を参照)そのため,これまでのSSMsはすべて計算効率の観点からLTIでした.しかし,本研究ではパラメータ(Δ,B,C)が入力依存となったことでシステムが時変系となりました.閑話休題.

改善点2:GPUアルゴリズムの効率化

改善点1で施した仕組みは従来のSSMモデルの性能を飛躍させうる非常に見込みのあるものですが,その反面,SSMモデルの前提を崩すような変更です.この前提の瓦解は計算効率を大幅に悪化させます.

そこで,筆者らはGPUの並列計算のアルゴリズムをハードウェアを意識したものに変更して,この計算効率の悪化をカバーしました.具体的には,莫大なメモリ使用量に対処するため,潜在変数をメモリ階層のより効率的なSRAM階層のみにマウントするように変更しました.これにより,メモリIOを削減し,通常の実装よりも大幅な高速化が実現できます.

結果

Mambaはこれらの改善を踏まえた上でいくつかの実験を実施し,その性能を評価しました.

その結果,Mambaは言語,オーディオ,ゲノミクスなど様々な分野において最先端の性能を達成したと報告しています.言語についてもMamba-3B(Mambaの改良モデルのようなもの)は同じサイズのTransformerモデルを上回ったと報告しています.

最後に

今回は深層学習分野の中から,最近注目を浴びているMambaという論文を紹介しました.

実はこの論文は,ICLR 2024という深層学習における世界トップクラスの学会に投稿されたものの残念ながらReject(不採用)されてしまいました.😭

Mamba: Linear-Time Sequence Modeling with Selective State Spaces
Foundation models, now powering most of the exciting applications in deep learning, are almost universally based on the Transformer architecture and its core attention module. Many...

それでも現在(2024/03/22)本アーカイブ論文(査読等が通っていない信ぴょう性の低い論文)では引用117件がついており,その注目度が伺えます.

このようなスピード感あふれる(?)研究動向は深層学習ならではではないかと思います.それでは今回はこの辺で.