Mamba: Linear-Time Sequence Modeling with Selective State Space Models
Introduction
Foundation Models (FMs) have emerged as an effective paradigm in modern ML. The backbone of these FMs is often a sequence model, predominantly based on the Transformer architecture. However, the quadratic complexity of the self-attention mechanism limits Transformers' ability to handle long sequences efficiently. Various efficient Transformer variants have been proposed to address this, but they often compromise on performance.
SSMs are a promising alternative for sequence modeling, capable of capturing long-range dependencies. However, existing SSM-based models such as S4 have been less effective at modeling discrete and information-dense data such as text. Traditional SSMs are also time-invariant, which limits their ability to adapt to varying contexts within a sequence.
This paper proposes selective state space models, which introduce time-variant dynamics to SSMs, allowing them to adapt to different parts of the input sequence while still scaling linearly in sequence length. To address the key limitation of prior SSMs, their inability to select data in an input-dependent manner, a simple selection mechanism is introduced that parameterizes the SSM parameters as functions of the input. This makes computation challenging because the fast convolution trick no longer applies, so a hardware-aware algorithm is designed to compute the model efficiently.
State Space Models
S4 models are defined by four parameters \((\Delta, A, B, C)\), which specify a sequence-to-sequence transformation in two stages.
- The first stage transforms the continuous parameters \((\Delta, A, B)\) to discrete parameters \((\bar{A}, \bar{B})\) through fixed formulas \(\bar{A} = f_A(\Delta, A)\) and \(\bar{B} = f_B(\Delta, A, B)\). The pair \((f_A, f_B)\) is called the discretization rule; various rules exist, and S4 uses the bilinear rule. Alternate flavors of SSMs can bypass the discretization step by directly parameterizing \((\bar{A}, \bar{B})\).
- After discretization, the model can be computed in two equivalent ways, either as a linear recurrence or as a global convolution. The recurrence is efficient for autoregressive inference, while the convolution is efficient for parallel training; see the sketch after this list.
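As a concrete illustration, here is a minimal NumPy sketch (not the paper's implementation) of an LTI SSM with a diagonal state and a single input channel, discretized with the zero-order-hold rule, showing that the recurrent and convolutional views produce the same output:

```python
import numpy as np

# Minimal LTI SSM sketch: diagonal state, scalar input channel.
np.random.seed(0)
N, L = 4, 16                      # state size, sequence length
A = -np.abs(np.random.randn(N))   # diagonal, stable continuous-time state matrix
B = np.random.randn(N)
C = np.random.randn(N)
Delta = 0.1

# Discretization (zero-order hold): Abar = exp(Delta*A),
# Bbar = (Delta*A)^{-1} (exp(Delta*A) - 1) * Delta*B  (elementwise for diagonal A)
Abar = np.exp(Delta * A)
Bbar = (Abar - 1.0) / A * B

x = np.random.randn(L)            # input sequence

# 1) Recurrent view: h_t = Abar * h_{t-1} + Bbar * x_t,  y_t = C . h_t
h = np.zeros(N)
y_rec = np.empty(L)
for t in range(L):
    h = Abar * h + Bbar * x[t]
    y_rec[t] = C @ h

# 2) Convolutional view: y = x * Kbar with kernel Kbar_k = C . Abar^k . Bbar
K = np.stack([C @ (Abar ** k * Bbar) for k in range(L)])
y_conv = np.array([np.sum(K[: t + 1][::-1] * x[: t + 1]) for t in range(L)])

assert np.allclose(y_rec, y_conv)
```

The recurrence consumes one token at a time with a constant-size state, which suits inference, while the precomputed kernel \(\bar{K}\) lets the whole sequence be processed as a single convolution during training.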
LTI: Linear Time-Invariant SSMs have fixed parameters \((\Delta, A, B, C)\) for all inputs, so the same transformation is applied at every time step. This limits the model's ability to adapt to different contexts in the sequence.
Selective State Space Models
Motivation
The fundamental problem of sequence modeling is compressing context into a smaller state. The attention mechanism is effective because it does not explicitly compress context, but it is also inefficient because it requires quadratic computation and memory. Recurrent models sit at the opposite end: they are efficient because they keep a fixed-size state, but their effectiveness is limited by how well that state has compressed the context. The fundamental principle for building effective sequence models is therefore selectivity: the ability to focus on or filter out inputs as they enter the state.
To illustrate this principle, the authors use two synthetic diagnostic tasks, Selective Copying and Induction Heads:
- Selective Copying asks: “Can you filter out noise and keep only the relevant tokens?” (a toy version is sketched after this list)
- Induction Heads asks: “Can you retrieve and use past associations when the context calls for it?”
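To make Selective Copying concrete, here is a hypothetical toy data generator (the function name and token layout are illustrative, not the paper's exact setup): a few content tokens are scattered among noise tokens at random positions, and the target is the content tokens in order, so solving the task requires content-dependent filtering rather than a fixed positional rule.

```python
import numpy as np

# Hypothetical Selective Copying-style example generator (illustrative only).
def selective_copying_example(n_content=4, seq_len=16, vocab=8, noise_token=0, rng=None):
    rng = rng or np.random.default_rng(0)
    content = rng.integers(1, vocab, size=n_content)              # tokens to remember
    positions = np.sort(rng.choice(seq_len, n_content, replace=False))
    inputs = np.full(seq_len, noise_token)                        # noise everywhere else
    inputs[positions] = content                                   # content at random spots
    return inputs, content                                        # target = content, in order

x, y = selective_copying_example()
print("input :", x)
print("target:", y)
```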
From the recurrent view, standard LTI models fail at these tasks because their constant dynamics \((\bar{A}, \bar{B})\) do not let them select the correct information from their context. From the convolutional view, global convolutions can solve the vanilla Copying task, but they fail at Selective Copying because they lack content-awareness.
Improving SSMs with Selection
The parameters \(\Delta, B, C\) are reparameterized to be input-dependent, allowing the model to adapt to different parts of the input sequence; the state matrix \(A\) is kept fixed to maintain the efficiency of the model. The new parameterization is \(s_B(x) = \mathrm{Linear}_N(x)\), \(s_C(x) = \mathrm{Linear}_N(x)\), \(s_{\Delta}(x) = \mathrm{Broadcast}_D(\mathrm{Linear}_1(x))\), and \(\tau_{\Delta} = \mathrm{softplus}\).
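Below is a minimal NumPy sketch of this selection mechanism under simplified assumptions: a single sequence, randomly initialized projection weights \(W_B, W_C, W_{\Delta}\) that are illustrative rather than the paper's initialization, and a simplified Euler-style \(\bar{B}\).

```python
import numpy as np

def softplus(z):
    return np.log1p(np.exp(z))

L, D, N = 16, 8, 4                      # sequence length, model dim, state dim
rng = np.random.default_rng(0)
x = rng.standard_normal((L, D))

A = -np.exp(rng.standard_normal((D, N)))         # fixed diagonal state matrix, per channel
W_B, W_C = rng.standard_normal((D, N)), rng.standard_normal((D, N))
W_dt = rng.standard_normal((D, 1))

# Input-dependent parameters (the selection mechanism)
B = x @ W_B                                      # s_B(x): (L, N)
C = x @ W_C                                      # s_C(x): (L, N)
dt = softplus(np.broadcast_to(x @ W_dt, (L, D))) # s_Delta(x) with tau_Delta = softplus: (L, D)

# Selective scan: per-step discretization, then the recurrence
h = np.zeros((D, N))
y = np.empty((L, D))
for t in range(L):
    Abar = np.exp(dt[t][:, None] * A)            # zero-order hold for A: (D, N)
    Bbar = dt[t][:, None] * B[t][None, :]        # simplified Euler-style Bbar: (D, N)
    h = Abar * h + Bbar * x[t][:, None]          # update per-channel state
    y[t] = h @ C[t]                              # read out with input-dependent C
```

Because \(\bar{A}_t\) and \(\bar{B}_t\) now change with every token, the model can ignore noise tokens (small \(\Delta\), state barely changes) or absorb relevant tokens (large \(\Delta\), state resets toward the new input).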
Implementation Details
The time-variant parameters make it impossible to use the fast convolution trick, so a hardware-aware algorithm is designed to compute the model efficiently. The key ideas are to compute the recurrence with a parallel scan instead of a convolution, to fuse the kernel so that the expanded state is materialized only in fast GPU SRAM rather than in HBM, and to recompute intermediate states during the backward pass instead of storing them.
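As a sketch of why a scan applies (this is not the paper's fused CUDA kernel), the time-variant recurrence \(h_t = \bar{A}_t h_{t-1} + \bar{B}_t x_t\) can be expressed as an associative combine over pairs \((\bar{A}_t,\ \bar{B}_t x_t)\), with combine rule \((a_1, b_1) \circ (a_2, b_2) = (a_2 a_1,\ a_2 b_1 + b_2)\), so it can be evaluated in \(O(\log L)\) parallel steps. The toy NumPy example below checks a Hillis-Steele-style scan against the sequential loop for a scalar state:

```python
import numpy as np

rng = np.random.default_rng(0)
L = 16
abar = rng.uniform(0.5, 1.0, L)     # per-step decay (from input-dependent Delta)
bbar = rng.standard_normal(L)       # per-step input contribution Bbar_t * x_t

# Sequential reference: h_t = abar_t * h_{t-1} + bbar_t
h, h_seq = 0.0, np.empty(L)
for t in range(L):
    h = abar[t] * h + bbar[t]
    h_seq[t] = h

# Parallel scan via recursive doubling with the associative combine
# (a1, b1) o (a2, b2) = (a2*a1, a2*b1 + b2); identity element is (1, 0).
a, b = abar.copy(), bbar.copy()
shift = 1
while shift < L:
    a_prev = np.concatenate([np.ones(shift), a[:-shift]])
    b_prev = np.concatenate([np.zeros(shift), b[:-shift]])
    a, b = a * a_prev, a * b_prev + b
    shift *= 2

assert np.allclose(b, h_seq)        # b_t now equals h_t
```

The associativity of the combine rule is what allows the work to be split across a tree of parallel steps instead of a strictly sequential loop.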