Mamba: Linear-Time Sequence Modeling with Selective State Spaces¶
这篇论文在前段时间非常火,然后最近学习了一下。网上有不少参考资料,为了方便查找列在下面。本文主要目的是记录一些好的参考资料以及some notes。
首先看这篇论文前应该对Transformer和RNN已经有比较深的学习了。S4 (Structured State Space for Sequence Modeling)模型我之前没有接触过,故先学习一下S4模型。
1 SSM (State Space Models):状态空间模型¶
空间状态模型就是用于描述不同的状态的表示,并根据某些输入预测下一个状态是什么的模型。
SSM包括输入序列\(x(t)\),hidden state \(h(t)\),预测输出序列\(y(t)\)。这里的序列**都是连续序列,而非离散**。
State equation就是\(h^{'}(t)=A\cdot h(t) + B\cdot x(t)\);Output equation就是\(y(t)=C\cdot h(t)+D\cdot x(t)\)
2 从SSM到S4¶
从SSM升级到S4主要有三个方面的升级:离散化、卷积表示、HiPPO。
2.1 SSM的离散化表示¶
这里引入一个符号\(\Delta\)表示步长,也就是输入的阶段性保持(resolution),也可以理解成“采样间隔”。这样对于上面的矩阵A、B可以得到其离散化表示: $$ \overline{\boldsymbol{A}}=\exp (\Delta \boldsymbol{A}) \quad \overline{\boldsymbol{B}}=(\Delta \boldsymbol{A})^{-1}(\exp (\Delta \boldsymbol{A})-\boldsymbol{I}) \cdot \Delta \boldsymbol{B} $$
2.2 卷积表示¶
目前看到的式子\(y_k\)依赖\(h_k\),而\(h_k\)依赖\(h_{k-1}\)。如果我们把式子展开会得到: $$ \begin{aligned} y_2 & =C h_2 \ & =C\left(\bar{A} h_1+\bar{B} x_2\right) \ & =C\left(\bar{A}\left(\bar{A} h_0+\bar{B} x_1\right)+\bar{B} x_2\right) \ & =C\left(\bar{A}\left(\bar{A} \cdot \bar{B} x_0+\bar{B} x_1\right)+\bar{B} x_2\right) \ & =C\left(\bar{A} \cdot \bar{A} \cdot \bar{B} x_0+\bar{A} \cdot \bar{B} x_1+\bar{B} x_2\right) \ & =C \cdot \bar{A}^2 \cdot \bar{B} x_0+C \cdot \bar{A} \cdot \bar{B} \cdot x_1+C \cdot \bar{B} x_2 \end{aligned} $$ 可以发现这其中的规律,于是我们可以考虑用卷积去“一步到位”地求解出想要的\(y_k\)。这里的卷积核不像CNN中的,这里的卷积核是一维的,如下:
这样子SSM在训练的时候可以转化为类CNN形式,可以并行训练;而推理时RNN更快,则在推理时采用RNN。这说明SSM是可以在RNN和CNN结构之间相互转化的。
2.3 HiPPO¶
2.4 总结¶
Note
- 没有non linear的函数,如relu等激活函数。这说明back propagation时候传播到每个元素的过程可以用一个矩阵乘法操作来表示,而不像RNN需要step by step形式。
- No time dependence. 不同time step对应的矩阵是一样的。