In the previous two articles, "Revisiting SSM (Part 1): Linear Systems and HiPPO Matrices" and "Revisiting SSM (Part 2): Some Legacy Issues of HiPPO", we introduced the core ideas and derivations of HiPPO—real-time approximation of continuously updated functions through orthogonal function bases. The dynamics of the fitting coefficients can be expressed as a linear ODE system, and for specific bases and approximation methods, we can precisely calculate the key matrices of the linear system. Furthermore, we discussed the discretization and related properties of HiPPO, which laid the theoretical foundation for subsequent SSM work.
Next, we will introduce the subsequent application paper of HiPPO, "Efficiently Modeling Long Sequences with Structured State Spaces" (referred to as S4). It utilizes the derivation results of HiPPO as a basic tool for sequence modeling and explores efficient computation and training methods from a new perspective. Finally, it validates its effectiveness on many long-sequence modeling tasks, making it one of the representative works of the SSM and even the RNN renaissance.
Basic Framework
The sequence modeling framework used by S4 is the following linear ODE system: \begin{equation} \begin{aligned} x'(t) =&\, A x(t) + B u(t) \\ y(t) =&\, C^* x(t) + D u(t) \end{aligned} \end{equation} Here u, y, D \in \mathbb{R}; x \in \mathbb{R}^d; A \in \mathbb{R}^{d \times d}; B, C \in \mathbb{R}^{d \times 1}. The symbol {}^* denotes the conjugate transpose operation; if the matrices are real, it is simply the transpose. Since a complete model usually includes a residual structure, the last term D u(t) can be integrated into the residual. Therefore, we can directly assume D=0 to simplify the form slightly without reducing the model’s capacity.
This system possesses similarity invariance. If \tilde{A} is a matrix similar to A, i.e., A = P^{-1}\tilde{A}P, then substituting and rearranging gives: \begin{equation} \begin{aligned} Px'(t) =&\, \tilde{A} Px(t) + PB u(t) \\ y(t) =&\, ((P^{-1})^* C)^* P x(t) \end{aligned} \end{equation} By treating Px(t) as a whole to replace the original x(t), the transformation of the system is (A, B, C) \to (\tilde{A}, PB, (P^{-1})^*C), but the output remains completely unchanged. This means that if there exists a similar matrix \tilde{A} of A that makes computation simpler, we can analyze the system entirely within \tilde{A} without changing the results. This is the core idea behind the subsequent series of analyses.
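The invariance is easy to check numerically. The sketch below is only an illustration: it assumes a forward-Euler discretization, random matrices, and a random complex P (all choices made here, not taken from S4). It runs (A, B, C) and (\tilde{A}, PB, (P^{-1})^*C) side by side on the same input. The helper `simulate` is hypothetical.

```python
import numpy as np

def simulate(A, B, C, u, eps):
    """Forward-Euler simulation of x' = Ax + Bu, y = C^* x, starting from x = 0."""
    x = np.zeros(A.shape[0], dtype=complex)
    ys = []
    for uk in u:
        x = x + eps * (A @ x + B * uk)
        ys.append(np.vdot(C, x))          # vdot conjugates its first argument: C^* x
    return np.array(ys)

rng = np.random.default_rng(0)
d, L, eps = 4, 50, 0.01
A = rng.standard_normal((d, d))
B = rng.standard_normal(d)
C = rng.standard_normal(d)
u = rng.standard_normal(L)

P = rng.standard_normal((d, d)) + 1j * rng.standard_normal((d, d))  # generic invertible P
P_inv = np.linalg.inv(P)
A_t = P @ A @ P_inv                       # \tilde{A}, so that A = P^{-1} \tilde{A} P
B_t = P @ B                               # PB
C_t = P_inv.conj().T @ C                  # (P^{-1})^* C

y = simulate(A, B, C, u, eps)
y_t = simulate(A_t, B_t, C_t, u, eps)
print(np.max(np.abs(y - y_t)))            # rounding-level difference only
```

The transformed state is exactly P times the original state at every step, so the outputs agree up to floating-point rounding.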
In particular, S4 selects the matrix A as the HiPPO-LegS matrix, namely: \begin{equation} A_{n,k} = -\left\{\begin{array}{ll}\sqrt{(2n+1)(2k+1)}, &k < n \\ n+1, &k = n \\ 0, &k > n\end{array}\right. \end{equation} The special thing about this choice is that the ODE we previously derived for LegS was of the form x'(t) = \frac{A}{t} x(t) + \frac{B}{t} u(t), while the ODE for LegT was of the form x'(t) = A x(t) + B u(t). So now, the LegT-style ODE is paired with the LegS A matrix. Therefore, the first question to ask is: what impact does such a combination have? For example, is its memory of history still complete and equal-weighted like LegS?
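For the later sketches it helps to have this matrix in code. Below is a minimal NumPy construction that follows the definition entry by entry (`hippo_legs` is a hypothetical helper name, not from any S4 codebase). Since the matrix is lower-triangular, its eigenvalues can be read off the diagonal: -1, -2, \dots, -d.

```python
import numpy as np

def hippo_legs(d):
    """HiPPO-LegS: A[n,k] = -sqrt((2n+1)(2k+1)) for k < n, -(n+1) for k = n, 0 for k > n."""
    q = np.sqrt(2 * np.arange(d) + 1.0)
    A = -np.tril(np.outer(q, q), k=-1)                # strictly lower-triangular part
    A -= np.diag(np.arange(1, d + 1, dtype=float))    # diagonal entries -(n+1)
    return A

A = hippo_legs(4)
print(A)
```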
Exponential Decay
The answer is no—the ODE system chosen by S4 has an exponentially decaying memory of history. We can understand this from two perspectives.
The first perspective starts from the transformation discussed in "Revisiting SSM (Part 2): Some Legacy Issues of HiPPO". The LegS-type ODE can be equivalently written as: \begin{equation} Ax(t) + Bu(t) = t x'(t) = \frac{d}{d\ln t} x(t) \end{equation} By setting \tau = \ln t, the LegS-type ODE becomes a LegT-type ODE in the time variable \tau, which is exactly the form of ODE used by S4. LegS treats every part of history equally, but only with respect to t, i.e., when the input is read as u(t) = u(e^{\tau}). S4's ODE instead takes the input directly as a function of \tau. Measured in \tau, the uniform weighting in t is no longer uniform: for t \in [0, T], the probability densities are related by dt/T = \rho(\tau)d\tau, which gives \rho(\tau) = e^{\tau}/T. The weight is therefore an exponential function of \tau, and more recent history (larger \tau) receives a larger weight.
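The change of variables behind this weight can be written out in one line. This is a worked check of the density above under the same assumptions (t = e^{\tau}, t \in [0, T]), nothing more:

```latex
\begin{equation}
t = e^{\tau},\qquad \frac{dt}{T} = \frac{e^{\tau}}{T}\,d\tau = \rho(\tau)\,d\tau,\qquad
\int_{-\infty}^{\ln T}\frac{e^{\tau}}{T}\,d\tau = \frac{e^{\ln T}}{T} = 1,\qquad
\frac{\rho(\tau)}{\rho(\tau-\Delta)} = e^{\Delta}
\end{equation}
```

So \rho is a proper density on \tau \in (-\infty, \ln T], and a moment that lies \Delta further in the past (in units of \tau) carries e^{-\Delta} times the weight.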
The second perspective requires a bit more linear algebra. Also in "Revisiting SSM (Part 2): Some Legacy Issues of HiPPO", we mentioned that the HiPPO-LegS matrix A can theoretically be diagonalized, and its eigenvalues are [-1, -2, -3, \dots]. Thus, there exists an invertible matrix P such that A = P^{-1}\Lambda P, where \Lambda = \text{diag}(-1, -2, \dots, -d). By similarity invariance, the original system is equivalent to the new system: \begin{equation} \begin{aligned} x'(t) =&\, \Lambda x(t) + PB u(t) \\ y(t) =&\, C^* P^{-1} x(t) \end{aligned} \end{equation} After discretization (using forward Euler as an example): \begin{equation} x(t+\epsilon) = (I + \epsilon\Lambda) x(t) + \epsilon P B u(t) \end{equation} Here I + \epsilon\Lambda is a diagonal matrix with entries 1-\epsilon, 1-2\epsilon, \dots, 1-d\epsilon, all of absolute value less than 1 as long as the step size is small enough (0 < \epsilon < 2/d). So at every iteration, each component of the historical state is multiplied by a factor smaller than 1 in magnitude, and after many steps this produces exponential decay.
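A small numerical illustration, assuming d = 8 and \epsilon = 0.01 (arbitrary choices here). Instead of diagonalizing, it uses the LegS matrix directly: I + \epsilon A is lower-triangular, so its eigenvalues are exactly the factors 1 - n\epsilon of the diagonalized system. With zero input, the state norm eventually shrinks like 0.99^k, the slowest of these factors.

```python
import numpy as np

def hippo_legs(d):
    """HiPPO-LegS matrix (lower-triangular, diagonal -1, -2, ..., -d)."""
    q = np.sqrt(2 * np.arange(d) + 1.0)
    return -np.tril(np.outer(q, q), k=-1) - np.diag(np.arange(1, d + 1, dtype=float))

d, eps = 8, 0.01                      # eps < 2/d, so every |1 - n*eps| < 1
M = np.eye(d) + eps * hippo_legs(d)   # forward-Euler transition matrix I + eps*A

# M is lower-triangular, so its eigenvalues are its diagonal entries 1 - n*eps, n = 1..d
print(np.diag(M))

x0 = np.ones(d)
steps = (0, 500, 1000, 2000)
norms = [np.linalg.norm(np.linalg.matrix_power(M, k) @ x0) for k in steps]
print(norms)   # eventually decays like 0.99**k, the slowest mode
```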
Discretization Schemes
Although exponential decay may seem less elegant than LegS's equal treatment of history, there is no free lunch. With a fixed-size memory state x(t), as the memory interval grows, treating every part of history equally means that every part of history is remembered more and more coarsely. For tasks where recent history matters more than distant history (the "near is large, far is small" principle), this is counterproductive. In addition, the right-hand side of the S4-type ODE does not depend explicitly on t, i.e., the system is time-invariant: after discretization with a fixed step, every step uses the same matrices, which helps training efficiency.
Having understood the memory properties of the S4-type ODE, we can proceed to the next step. To handle discrete sequences in practice, we first need to discretize. In the previous article, we provided two high-precision discretization schemes. One is the bilinear form: \begin{equation} x_{k+1} = (I - \epsilon A/2)^{-1}[(I + \epsilon A/2) x_k + \epsilon B u_k] \end{equation} It has second-order accuracy. S4 adopts this discretization scheme, which is also the format explored in the rest of this article. The other is based on the exact solution of the ODE with constant input, yielding: \begin{equation} x_{k+1} = e^{\epsilon A} x_k + A^{-1} (e^{\epsilon A} - I) B u_k \end{equation} The author’s subsequent works, including Mamba, use this format. In this case, it is generally assumed that A is a diagonal matrix, because for the LegS matrix A, the matrix exponential is not easy to compute.
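To see how the two schemes relate, here is a sketch for a diagonal A (as the text assumes for the second scheme), with the LegS eigenvalues -1, \dots, -4 plugged in purely for illustration. The helper names are hypothetical. The bilinear factor (1 + \epsilon\lambda/2)/(1 - \epsilon\lambda/2) agrees with e^{\epsilon\lambda} up to a third-order local error, consistent with its second-order accuracy.

```python
import numpy as np

def bilinear_diag(lam, B, eps):
    """Bilinear scheme for diagonal A = diag(lam): diagonal of A_bar, and B_bar."""
    a = (1 + eps * lam / 2) / (1 - eps * lam / 2)
    b = eps * B / (1 - eps * lam / 2)
    return a, b

def zoh_diag(lam, B, eps):
    """Exact-for-constant-input scheme for diagonal A: e^{eps A} and A^{-1}(e^{eps A} - I) B."""
    a = np.exp(eps * lam)
    b = (a - 1) / lam * B
    return a, b

lam = -np.arange(1, 5, dtype=float)   # e.g. the LegS eigenvalues for d = 4
B = np.ones(4)
errs = {}
for eps in (0.1, 0.01):
    a_bil, _ = bilinear_diag(lam, B, eps)
    a_zoh, _ = zoh_diag(lam, B, eps)
    errs[eps] = np.max(np.abs(a_bil - a_zoh))
print(errs)   # the gap shrinks by well over 100x when eps shrinks by 10x
```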
Now we denote: \begin{equation} \bar{A}=(I - \epsilon A/2)^{-1}(I + \epsilon A/2),\quad\bar{B}=\epsilon(I - \epsilon A/2)^{-1}B,\quad\bar{C}=C \end{equation} Then we obtain a linear RNN: \begin{equation} \begin{aligned} x_{k+1} =&\, \bar{A} x_k + \bar{B} u_k \\ y_{k+1} =&\, \bar{C}^* x_{k+1} \\ \end{aligned} \label{eq:s4-r} \end{equation} where \epsilon > 0 is the discretization step size, which is a manually chosen hyperparameter.
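A minimal sketch of this discretization applied to the LegS matrix, followed by running the resulting linear RNN. The value \epsilon = 0.1, B = C = \mathbf{1}, and the helper names are assumptions made for illustration. It also shows a property of the bilinear scheme: \bar{A} is lower-triangular with diagonal (1 - \epsilon n/2)/(1 + \epsilon n/2), which lies in (-1, 1) for every \epsilon > 0.

```python
import numpy as np

def hippo_legs(d):
    q = np.sqrt(2 * np.arange(d) + 1.0)
    return -np.tril(np.outer(q, q), k=-1) - np.diag(np.arange(1, d + 1, dtype=float))

def discretize_bilinear(A, B, eps):
    """Return (A_bar, B_bar) for the bilinear scheme; C_bar = C is unchanged."""
    I = np.eye(A.shape[0])
    left = I - eps * A / 2
    A_bar = np.linalg.solve(left, I + eps * A / 2)   # (I - eps A/2)^{-1}(I + eps A/2)
    B_bar = eps * np.linalg.solve(left, B)           # eps (I - eps A/2)^{-1} B
    return A_bar, B_bar

def run_rnn(A_bar, B_bar, C, u):
    """x_{k+1} = A_bar x_k + B_bar u_k, y_{k+1} = C^T x_{k+1}, starting from x_0 = 0."""
    x = np.zeros(A_bar.shape[0])
    ys = []
    for uk in u:
        x = A_bar @ x + B_bar * uk
        ys.append(C @ x)
    return np.array(ys)

d, eps = 16, 0.1
A_bar, B_bar = discretize_bilinear(hippo_legs(d), np.ones(d), eps)
n = np.arange(1, d + 1)
print(np.diag(A_bar))                                # (1 - eps*n/2) / (1 + eps*n/2)
y = run_rnn(A_bar, B_bar, np.ones(d), np.sin(np.arange(100) / 5))
```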
Convolutional Operations
In the previous article, we also mentioned that the HiPPO-LegS matrix A is computationally efficient: multiplying A or \bar{A} by a vector costs \mathcal{O}(d) rather than the general \mathcal{O}(d^2). However, this only makes each step of the recursion above cheaper than in a general RNN. The recursion is still sequential in time, so efficient training requires a way to parallelize across the sequence.
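For A itself, the \mathcal{O}(d) cost comes from the identity (Av)_n = n v_n - \sqrt{2n+1}\sum_{k=0}^n \sqrt{2k+1}v_k, which reappears in the eigenvector analysis later: one cumulative sum replaces the dense product. A minimal sketch (`legs_matvec` is a hypothetical helper; the \bar{A} case, which needs a structured triangular solve, is not shown):

```python
import numpy as np

def hippo_legs(d):
    q = np.sqrt(2 * np.arange(d) + 1.0)
    return -np.tril(np.outer(q, q), k=-1) - np.diag(np.arange(1, d + 1, dtype=float))

def legs_matvec(v):
    """O(d) product A v using (Av)_n = n v_n - sqrt(2n+1) * sum_{k<=n} sqrt(2k+1) v_k."""
    n = np.arange(len(v))
    q = np.sqrt(2 * n + 1.0)
    return n * v - q * np.cumsum(q * v)

rng = np.random.default_rng(0)
v = rng.standard_normal(64)
print(np.allclose(legs_matvec(v), hippo_legs(64) @ v))  # True
```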
There are two ideas for parallelizing linear RNNs: one is to treat it as a Prefix Sum problem, as introduced in "Google’s New Work Attempts to ’Revive’ RNNs: Can RNNs Shine Again?", using Associative Scan algorithms such as Upper/Lower, Odd/Even, or Ladner-Fischer. References can be found in "Prefix Sums and Their Applications". The other is to transform it into a convolutional operation between a matrix sequence and a vector sequence, utilizing the Fast Fourier Transform (FFT) for acceleration. This is the approach taken by S4. Regardless of the method, they face a common bottleneck: the calculation of the matrix power \bar{A}^k.
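To make the prefix-sum view concrete: each step x_{k+1} = \bar{A}x_k + \bar{B}u_k is an affine map, and composing affine maps is associative. The sketch below uses a simple Hillis–Steele style inclusive scan, which is a swapped-in choice and not one of the Upper/Lower, Odd/Even, or Ladner-Fischer variants named above. It needs about \log_2 L rounds of batched products, but it materializes a d \times d matrix per position, which is exactly the \bar{A}^k bottleneck. The function name is hypothetical.

```python
import numpy as np

def scan_linear_rnn(A_bar, B_bar, u):
    """All states x_1..x_L of x_{k+1} = A_bar x_k + B_bar u_k (x_0 = 0) via a log-depth scan.

    Element k is the affine map x -> M_k x + b_k; combining an earlier map (M1, b1)
    with a later one (M2, b2) gives (M2 M1, M2 b1 + b2), which is associative.
    """
    L, d = len(u), A_bar.shape[0]
    Ms = np.broadcast_to(A_bar, (L, d, d)).copy()
    bs = np.outer(u, B_bar)                      # b_k = B_bar u_k, shape (L, d)
    shift = 1
    while shift < L:
        M_prev, b_prev = Ms[:-shift], bs[:-shift]
        M_cur, b_cur = Ms[shift:], bs[shift:]
        new_M = M_cur @ M_prev                   # batched d x d products
        new_b = np.einsum('lij,lj->li', M_cur, b_prev) + b_cur
        Ms = np.concatenate([Ms[:shift], new_M])
        bs = np.concatenate([bs[:shift], new_b])
        shift *= 2
    return bs                                    # bs[k] = x_{k+1}

rng = np.random.default_rng(0)
d, L = 4, 37
A_bar = 0.3 * rng.standard_normal((d, d))
B_bar = rng.standard_normal(d)
u = rng.standard_normal(L)

x, xs = np.zeros(d), []
for uk in u:                                     # sequential reference
    x = A_bar @ x + B_bar * uk
    xs.append(x)
print(np.allclose(scan_linear_rnn(A_bar, B_bar, u), np.array(xs)))  # True
```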
Specifically, we usually set the initial state x_0 to 0, which allows us to write: \begin{equation} \begin{aligned} y_1 =&\, \bar{C}^*\bar{B}u_0\\ y_2 =&\, \bar{C}^*(\bar{A}x_1 + \bar{B}u_1) = \bar{C}^*\bar{A}\bar{B}u_0 + \bar{C}^*\bar{B}u_1\\ y_3 =&\, \bar{C}^*(\bar{A}x_2 + \bar{B}u_2) = \bar{C}^*\bar{A}^2 \bar{B}u_0 + \bar{C}^*\bar{A}\bar{B}u_1 + \bar{C}^*\bar{B}u_2\\[5pt] \vdots & \\ y_L =&\, \bar{C}^*(\bar{A} x_{L-1}+\bar{B}u_{L-1}) = \sum_{k=0}^{L-1} \bar{C}^*\bar{A}^k \bar{B}u_{L-1-k} = \big(\bar{K}_{< L} * u_{< L}\big)_{L-1} \end{aligned} \end{equation} where * denotes the convolution operation, the subscript L-1 picks its (L-1)-th entry (counting from 0), and \begin{equation} \bar{K}_k = \bar{C}^*\bar{A}^k\bar{B},\quad \bar{K}_{< L} = \big(\bar{K}_0,\bar{K}_1,\dots,\bar{K}_{L-1}\big),\quad u_{< L} = (u_0,u_1,\dots,u_{L-1}) \end{equation} In other words, (y_1, \dots, y_L) are exactly the first L entries of \bar{K}_{< L} * u_{< L}. Note that under the current conventions, \bar{C}^*\bar{A}^k \bar{B} and u_k are scalars, so \bar{K}_{< L}, u_{< L} \in \mathbb{R}^L. Convolution can be converted into pointwise multiplication in the frequency domain via the (Discrete) Fourier Transform and then transformed back, with a complexity of \mathcal{O}(L \log L). Although this is asymptotically higher than the \mathcal{O}(L) of direct recursion, the Fourier Transform is highly parallelizable, which makes it faster in practice.
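A quick check that the recurrent and convolutional forms agree, using a random \bar{A} (an assumption for illustration, not the LegS matrix). Building the kernel naively like this costs L matrix-vector products, which is the cost the rest of the article tries to avoid.

```python
import numpy as np

rng = np.random.default_rng(0)
d, L = 4, 20
A_bar = 0.3 * rng.standard_normal((d, d))
B_bar, C = rng.standard_normal(d), rng.standard_normal(d)
u = rng.standard_normal(L)

# Recurrent form: x_{k+1} = A_bar x_k + B_bar u_k, y_{k+1} = C^T x_{k+1}, x_0 = 0
x, y_rnn = np.zeros(d), []
for uk in u:
    x = A_bar @ x + B_bar * uk
    y_rnn.append(C @ x)

# Convolutional form: K_k = C^T A_bar^k B_bar, then y_{j+1} = sum_{k<=j} K_k u_{j-k}
K, v = [], B_bar.copy()
for _ in range(L):
    K.append(C @ v)
    v = A_bar @ v                        # v = A_bar^k B_bar
y_conv = np.convolve(K, u)[:L]           # keep the first L (causal) outputs

print(np.allclose(y_rnn, y_conv))        # True
```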
The problem now is how to compute the convolution kernel \bar{K}_{< L} efficiently, since it involves the matrix powers \bar{A}^k, which are expensive to compute by definition. \bar{A} is a constant matrix once \epsilon is fixed, so its powers could in principle be precomputed. However, \bar{B} and \bar{C} depend on B and C, which are trainable parameters in S4, so \bar{C}^*\bar{A}^k\bar{B} cannot be fully precomputed and has to be re-evaluated whenever B and C change.
Generating Functions
Before further analysis, let’s introduce the concept of generating functions, which is a fundamental step for efficient computation.
For a given sequence a = (a_0, a_1, a_2, \dots), its generating function is the power series that uses the sequence components as coefficients: \begin{equation} \mathcal{G}(z|a) = \sum_{k=0}^{\infty} a_k z^k \end{equation} If we have two sequences a and b, the product of their generating functions is: \begin{equation} \mathcal{G}(z|a)\mathcal{G}(z|b) = \left(\sum_{k=0}^{\infty} a_k z^k\right)\left(\sum_{l=0}^{\infty} b_l z^l\right) = \sum_{l=0}^{\infty}\left(\sum_{k=0}^l a_k b_{l-k}\right) z^l \end{equation} The coefficient of z^l in \mathcal{G}(z|a)\mathcal{G}(z|b) is exactly the l-th entry of the convolution a * b, which only involves a_{< l+1} and b_{< l+1}. So if we can efficiently evaluate generating functions and recover coefficients from their values, convolution reduces to pointwise multiplication.
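A tiny example of this correspondence: (1 + 2z + 3z^2)(4 + 5z) = 4 + 13z + 22z^2 + 15z^3, and the coefficient list is the convolution of the two coefficient lists. The sketch uses NumPy's polynomial module, whose coefficient arrays are ordered from lowest degree to highest.

```python
import numpy as np
from numpy.polynomial import polynomial as npoly

a = np.array([1.0, 2.0, 3.0])     # G(z|a) = 1 + 2z + 3z^2
b = np.array([4.0, 5.0])          # G(z|b) = 4 + 5z
print(npoly.polymul(a, b))        # coefficients 4, 13, 22, 15 (lowest degree first)
print(np.convolve(a, b))          # the same numbers: product of generating functions = convolution
```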
The Discrete Fourier Transform (DFT) is exactly such a method. If we only need the first L terms of the convolution, we can truncate the sums at L-1. The DFT evaluates the truncated generating function at the specific points z=e^{-2i\pi l/L}, l=0,1,\dots,L-1: \begin{equation} \hat{a}_l = \sum_{k=0}^{L-1} a_k e^{-2i\pi kl/L} \end{equation} The inverse transform (IDFT) recovers the coefficients: \begin{equation} a_k = \frac{1}{L}\sum_{l=0}^{L-1} \hat{a}_l e^{2i\pi kl/L} \end{equation} Both can be computed in \mathcal{O}(L \log L) via the Fast Fourier Transform (FFT). Because these points satisfy z^L = 1, multiplying two length-L DFTs corresponds to a cyclic convolution, in which indices wrap around modulo L. To obtain the ordinary convolution, we typically pad both sequences with L zeros, doubling the period.
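The sketch below checks both points with NumPy's FFT routines on random data: `np.fft.fft` is the generating function evaluated at z_l = e^{-2i\pi l/L}, a plain product of length-L DFTs folds entry j + L of the linear convolution onto entry j, and zero-padding to 2L removes that wrap-around.

```python
import numpy as np

rng = np.random.default_rng(0)
L = 8
a, b = rng.standard_normal(L), rng.standard_normal(L)

# DFT = truncated generating function evaluated at z_l = exp(-2 i pi l / L)
z = np.exp(-2j * np.pi * np.arange(L) / L)
G_a = np.array([np.sum(a * zl ** np.arange(L)) for zl in z])
print(np.allclose(G_a, np.fft.fft(a)))            # True

full = np.convolve(a, b)                          # linear convolution, length 2L - 1

# Multiplying length-L DFTs gives a cyclic convolution: entries j and j + L fold together
cyclic = np.fft.ifft(np.fft.fft(a) * np.fft.fft(b)).real
folded = full[:L].copy()
folded[:L - 1] += full[L:]
print(np.allclose(cyclic, folded))                # True

# Zero-padding to length 2L leaves room for all 2L - 1 terms, so nothing wraps around
linear = np.fft.irfft(np.fft.rfft(a, 2 * L) * np.fft.rfft(b, 2 * L), 2 * L)[:L]
print(np.allclose(linear, full[:L]))              # True
```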
From Powers to Inverses
For the convolution kernel \bar{K}, we have: \begin{equation} \mathcal{G}(z|\bar{K}) = \sum_{k=0}^{\infty} \bar{C}^*\bar{A}^k \bar{B}z^k = \bar{C}^*\left(I - z\bar{A}\right)^{-1}\bar{B} \label{eq:k-gen} \end{equation} The generating function transforms the calculation of the power matrix \bar{A}^k into the calculation of the matrix inverse (I - z\bar{A})^{-1}.
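A numerical sanity check of this identity. The random \bar{A} is rescaled to spectral radius 0.5 and a small z is picked so that the series converges; both choices are assumptions made here so the infinite sum can be approximated by 200 terms.

```python
import numpy as np

rng = np.random.default_rng(0)
d = 4
M = rng.standard_normal((d, d))
A_bar = 0.5 * M / np.max(np.abs(np.linalg.eigvals(M)))   # spectral radius 0.5
B_bar, C = rng.standard_normal(d), rng.standard_normal(d)
z = 0.5 + 0.2j                       # |z| * spectral_radius(A_bar) < 1, so the series converges

# Left side: sum_k C^T A_bar^k B_bar z^k, truncated once terms are negligible
total, v = 0.0, B_bar.astype(complex)
for k in range(200):
    total += C @ v
    v = z * (A_bar @ v)              # v = (z A_bar)^k B_bar
# Right side: C^T (I - z A_bar)^{-1} B_bar, a single linear solve instead of many powers
closed = C @ np.linalg.solve(np.eye(d) - z * A_bar, B_bar)
print(np.allclose(total, closed))    # True
```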
What kind of matrix \bar{A} makes (I - z\bar{A})^{-1} easy to compute? Diagonal matrices are ideal, and diagonalizable ones are nearly as good: if \bar{A} = P^{-1}\bar{\Lambda}P, then: \begin{equation} (I - z\bar{A})^{-1} = P^{-1}(I - z\bar{\Lambda})^{-1}P \end{equation} Can \bar{A} be diagonalized? This depends on A: since \bar{A} is a rational function of A, if A = P^{-1}\Lambda P, then \bar{A} is diagonalized by the same P. In exact arithmetic, the LegS matrix is diagonalizable, because its eigenvalues [-1, -2, -3, \dots] are distinct (more generally, almost all matrices are diagonalizable over the complex field). In practice, however, the eigenvector matrix P of HiPPO-LegS is so ill-conditioned that diagonalization in finite precision is numerically unstable.
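One way to see the problem is to ask NumPy for the eigenvectors and look at the condition number of the eigenvector matrix. This is a rough empirical sketch: the sizes d = 8, 16, 32 are arbitrary, and the exact numbers depend on the LAPACK build, but the growth with d is the point.

```python
import numpy as np

def hippo_legs(d):
    q = np.sqrt(2 * np.arange(d) + 1.0)
    return -np.tril(np.outer(q, q), k=-1) - np.diag(np.arange(1, d + 1, dtype=float))

conds = {}
for d in (8, 16, 32):
    _, V = np.linalg.eig(hippo_legs(d))   # columns of V are eigenvectors, i.e. V plays the role of P^{-1}
    conds[d] = np.linalg.cond(V)          # cond(P^{-1}) = cond(P)
print(conds)                               # grows exponentially with d
```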
Eigenvectors
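Diagonalizing A is equivalent to finding its eigenvectors. Write the eigenvalues as -\lambda with \lambda \in \{1, 2, \dots, d\}; for each \lambda we solve (A + \lambda I)v = 0. Using the expression for Av from the previous article: \begin{equation} (Av)_n = n v_n -\sqrt{2n+1}\sum_{k=0}^n \sqrt{2k+1}v_k \end{equation} Setting -Av = \lambda v and introducing the partial sums S_n = \sum_{k=0}^n \sqrt{2k+1}v_k (so that v_n = (S_n - S_{n-1})/\sqrt{2n+1}) leads to the recurrence: \begin{equation} S_{n-1} = \frac{\lambda - n - 1}{\lambda + n}S_n \end{equation} It forces S_n = 0 for n < \lambda - 1 and leaves S_{\lambda-1} free. Normalizing S_{\lambda-1} = 1 and unrolling upward gives |S_{d-1}| = \binom{d-1+\lambda}{2\lambda-1}. Taking \lambda \approx d/3, where this binomial coefficient is central, it grows like 2^{4d/3}/\sqrt{d}; equivalently, if we normalize S_{d-1} = 1 instead, the early components are exponentially small: \begin{equation} |S_{d/3}| \sim \mathcal{O}(\sqrt{d}\,2^{-4d/3}) \end{equation} So the components of a single eigenvector span an exponentially large range of magnitudes. This exponential decay (or, read in the other direction, explosion) makes the matrix P extremely ill-conditioned, leading to numerical instability.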
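The recurrence can be run in exact rational arithmetic to confirm both claims: it really produces eigenvectors, and the ratio |S_{d-1}/S_{\lambda-1}| is the binomial coefficient above. A sketch with hypothetical helper names; d = 30, \lambda = 10 is the case \lambda = d/3.

```python
import math
from fractions import Fraction
import numpy as np

def eigvec_S(d, lam):
    """Exact partial sums S_0..S_{d-1} for eigenvalue -lam (lam in 1..d), normalized so S_{lam-1} = 1.

    S_n = 0 for n < lam - 1, and S_n = S_{n-1} (lam + n) / (lam - n - 1) for n >= lam.
    """
    S = [Fraction(0)] * d
    S[lam - 1] = Fraction(1)
    for n in range(lam, d):
        S[n] = S[n - 1] * Fraction(lam + n, lam - n - 1)
    return S

def eigvec(d, lam):
    S = eigvec_S(d, lam)
    # v_n = (S_n - S_{n-1}) / sqrt(2n + 1), with S_{-1} = 0
    return np.array([float(S[n] - (S[n - 1] if n > 0 else 0)) / math.sqrt(2 * n + 1) for n in range(d)])

def hippo_legs(d):
    q = np.sqrt(2 * np.arange(d) + 1.0)
    return -np.tril(np.outer(q, q), k=-1) - np.diag(np.arange(1, d + 1, dtype=float))

# The recurrence really produces eigenvectors: A v = -lam v
d, lam = 10, 3
v = eigvec(d, lam)
print(np.allclose(hippo_legs(d) @ v, -lam * v))       # True

# Dynamic range inside one eigenvector: |S_{d-1} / S_{lam-1}| = C(d - 1 + lam, 2 lam - 1)
d, lam = 30, 10
ratio = abs(eigvec_S(d, lam)[d - 1])
print(ratio == math.comb(d - 1 + lam, 2 * lam - 1), float(ratio))   # True 68923264410.0
```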
Diagonal Plus Low-Rank
If A (and hence \bar{A}) cannot be diagonalized stably, can we exploit other structure? Recall the Woodbury identity: \begin{equation} (M - UV^*)^{-1} = M^{-1} + M^{-1}U(I - V^*M^{-1}U)^{-1} V^*M^{-1} \end{equation} If M is easy to invert (e.g., diagonal) and U, V are low-rank (have few columns), then inverting M - UV^* only requires inverting M and the small matrix I - V^*M^{-1}U. The HiPPO-LegS matrix A can be written as: \begin{equation} A_{n,k} = \left\{\begin{array}{ll}n\delta_{n,k} - \sqrt{2n+1}\sqrt{2k+1}, &k \leq n \\ 0, &k > n\end{array}\right. \end{equation} This looks like "diagonal minus low-rank" (the diagonal \text{diag}(0, 1, 2, \dots) minus the rank-1 matrix with entries \sqrt{2n+1}\sqrt{2k+1}), except that the rank-1 part is cut off to its lower triangle, which destroys the low-rank structure.
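A quick numerical check of the Woodbury identity with a diagonal M and rank-1 factors (random values, chosen here only for illustration), followed by the LegS matrix written in the "diagonal minus lower-triangular rank-1" form above.

```python
import numpy as np

rng = np.random.default_rng(0)
d = 6
m = rng.uniform(1.0, 2.0, d)             # diagonal of M: trivially invertible
U = rng.standard_normal((d, 1))          # rank-1 factors U, V (real, so V^* = V^T)
V = rng.standard_normal((d, 1))

M_inv = np.diag(1.0 / m)
small = np.linalg.inv(np.eye(1) - V.T @ M_inv @ U)   # only a 1x1 matrix is inverted
woodbury = M_inv + M_inv @ U @ small @ V.T @ M_inv
direct = np.linalg.inv(np.diag(m) - U @ V.T)
print(np.allclose(woodbury, direct))     # True

# LegS written as n*delta_{nk} - sqrt(2n+1)sqrt(2k+1) on and below the diagonal
q = np.sqrt(2 * np.arange(d) + 1.0)
A = np.diag(np.arange(d, dtype=float)) - np.tril(np.outer(q, q))
print(np.diag(A))                        # diagonal entries -(n+1), as in the LegS definition
```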
The Finishing Touch
The authors of S4 used a clever trick. Consider A + \frac{1}{2}vv^* + \frac{1}{2}I, where v = [1, \sqrt{3}, \sqrt{5}, \dots]^*. This matrix turns out to be antisymmetric: \begin{equation} \left(A + \frac{1}{2}v v^*+\frac{1}{2}I\right)_{n,k} = \left\{\begin{array}{ll} - \frac{1}{2}\sqrt{2n+1}\sqrt{2k+1}, &k < n \\ 0, &k=n \\ \frac{1}{2}\sqrt{2n+1}\sqrt{2k+1}, &k > n\end{array}\right. \end{equation} Antisymmetric matrices are normal, so they can always be diagonalized by a unitary matrix (with purely imaginary eigenvalues), and a unitary change of basis has condition number 1, so it is numerically stable. Writing A + \frac{1}{2}vv^* + \frac{1}{2}I = U^*\Lambda_S U, we can therefore decompose A as: \begin{equation} A = U^*(\Lambda - uv^*) U \end{equation} where \Lambda = \Lambda_S - \frac{1}{2}I is diagonal, U is unitary, and (overloading notation) u and v on the right-hand side both denote the transformed vector \frac{1}{\sqrt{2}}U[1, \sqrt{3}, \sqrt{5}, \dots]^*. In the new basis, this "Diagonal Plus Low-Rank" (DPLR) structure allows \mathcal{O}(d) matrix-vector multiplication and, via the Woodbury identity, efficient inversion.
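A sketch of this construction in NumPy, assuming d = 16. It uses `np.linalg.eigh` on the Hermitian matrix iS (S being the antisymmetric matrix above) to obtain a unitary eigenbasis, then rebuilds A from the DPLR pieces. Variable names are illustrative only.

```python
import numpy as np

def hippo_legs(d):
    q = np.sqrt(2 * np.arange(d) + 1.0)
    return -np.tril(np.outer(q, q), k=-1) - np.diag(np.arange(1, d + 1, dtype=float))

d = 16
A = hippo_legs(d)
v = np.sqrt(2 * np.arange(d) + 1.0)              # v = [1, sqrt(3), sqrt(5), ...]
S = A + 0.5 * np.outer(v, v) + 0.5 * np.eye(d)   # antisymmetric: S^T = -S
print(np.allclose(S.T, -S))                      # True

# iS is Hermitian, so eigh gives a unitary eigenbasis: iS = Q diag(w) Q^*, i.e. S = Q diag(-i w) Q^*
w, Q = np.linalg.eigh(1j * S)
U = Q.conj().T                                   # S = U^* diag(-i w) U
Lam = -1j * w - 0.5                              # Lambda = Lambda_S - I/2 (real parts all -1/2)
u = U @ v / np.sqrt(2)                           # both low-rank factors equal this u

A_rebuilt = U.conj().T @ (np.diag(Lam) - np.outer(u, u.conj())) @ U
print(np.allclose(A_rebuilt, A))                 # True
print(np.linalg.cond(U))                         # 1 up to rounding: a unitary change of basis
```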
Final Sprint
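With A = U^*(\Lambda - uv^*) U, similarity invariance lets us work directly in the DPLR basis, i.e., replace (A, B, C) by (\Lambda - uv^*, UB, UC); to keep the notation light, we keep writing A, B, C for the transformed quantities. Substituting the definitions of \bar{A} and \bar{B} into the generating function of \bar{K}, it becomes: \begin{equation} \mathcal{G}(z|\bar{K}) = \frac{2}{1+z}\bar{C}^* \left[\frac{2}{\epsilon}\frac{1-z}{1+z}I - A\right]^{-1}B \end{equation} Substituting the DPLR form A = \Lambda - uv^* and applying the Woodbury identity: \begin{equation} \mathcal{G}(z|\bar{K}) = \frac{2}{1+z}\bar{C}^* \left[R_z^{-1} - R_z^{-1}u(I + v^*R_z^{-1}u)^{-1} v^*R_z^{-1}\right]B \end{equation} where R_z = \frac{2}{\epsilon}\frac{1-z}{1+z}I - \Lambda is diagonal and I + v^*R_z^{-1}u is a 1\times 1 matrix, i.e., a scalar. Every term therefore costs only \mathcal{O}(d) per value of z. Evaluating at the L points z = e^{-2i\pi l/L} gives the DFT of \bar{K}_{< L}, and a single inverse FFT recovers the kernel. One subtlety remains: the DFT needs only the truncated sum \sum_{k=0}^{L-1}\bar{C}^*\bar{A}^k\bar{B}z^k, which at these points (where z^L = 1) equals \bar{C}^*(I - \bar{A}^L)(I - z\bar{A})^{-1}\bar{B} rather than the infinite sum. To avoid computing \bar{A}^L during training, S4 treats \tilde{C}^* = \bar{C}^*(I - \bar{A}^L) directly as the learnable parameter in place of \bar{C}^*.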
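Putting the pieces together, here is an end-to-end numerical sketch, not the official S4 implementation. Assumptions made here: random real B and C, \epsilon = 0.05, d = 16, and an odd L = 63 so that z = -1 (where the factor 2/(1+z) is singular) is never an evaluation point. It builds the kernel both by brute force and through the DPLR basis plus the Woodbury identity at the roots of unity, then inverts the DFT.

```python
import numpy as np

def hippo_legs(d):
    q = np.sqrt(2 * np.arange(d) + 1.0)
    return -np.tril(np.outer(q, q), k=-1) - np.diag(np.arange(1, d + 1, dtype=float))

rng = np.random.default_rng(0)
d, eps, L = 16, 0.05, 63          # L odd, so z = -1 (where 2/(1+z) is singular) is never used
A = hippo_legs(d)
B, C = rng.standard_normal(d), rng.standard_normal(d)

# Reference: bilinear discretization and K_k = C^T A_bar^k B_bar by brute force
I = np.eye(d)
A_bar = np.linalg.solve(I - eps * A / 2, I + eps * A / 2)
B_bar = eps * np.linalg.solve(I - eps * A / 2, B)
K_ref, x = [], B_bar.copy()
for _ in range(L):
    K_ref.append(C @ x)
    x = A_bar @ x
K_ref = np.array(K_ref)

# DPLR form A = U^* (diag(Lam) - u u^*) U, built from the antisymmetric part
v = np.sqrt(2 * np.arange(d) + 1.0)
w, Q = np.linalg.eigh(1j * (A + 0.5 * np.outer(v, v) + 0.5 * I))
U, Lam, u = Q.conj().T, -1j * w - 0.5, Q.conj().T @ v / np.sqrt(2)

# Truncation: at z^L = 1 the first L terms sum to C^T (I - A_bar^L)(I - z A_bar)^{-1} B_bar,
# so C_tilde^* = C^T (I - A_bar^L) absorbs A_bar^L (this is the vector S4 learns directly)
C_tilde = (I - np.linalg.matrix_power(A_bar, L)).T @ C
Ct, Bt = U @ C_tilde, U @ B          # move to the DPLR basis (similarity invariance)

z = np.exp(-2j * np.pi * np.arange(L) / L)
R = (2 / eps) * ((1 - z) / (1 + z))[:, None] - Lam[None, :]   # R_z, one diagonal per z
# Woodbury for (R_z + u u^*)^{-1}: only O(d) elementwise work per z
cB = (Ct.conj() / R) @ Bt            # C^* R^{-1} B
cu = (Ct.conj() / R) @ u             # C^* R^{-1} u
uB = (u.conj() / R) @ Bt             # u^* R^{-1} B
uu = (u.conj() / R) @ u              # u^* R^{-1} u
K_hat = 2 / (1 + z) * (cB - cu * uB / (1 + uu))

K = np.fft.ifft(K_hat).real          # K_hat[l] = sum_k K_k z_l^k is a DFT, so invert it
print(np.allclose(K, K_ref))         # True
```

The brute-force loop is there only for comparison; the second half never forms \bar{A}^k for k < L, and the single \bar{A}^L it uses is exactly what learning \tilde{C} directly removes.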
Conclusion
S4 is a significant extension of HiPPO. Its key contribution is the DPLR decomposition of the HiPPO matrix, enabling efficient parallel training via FFT. While modern SSMs like Mamba often simplify A to be strictly diagonal, the mathematical insights of S4 remain invaluable for understanding memory in linear systems.
Summary
This article introduced S4, focusing on the "Diagonal + Low-Rank" decomposition that enables efficient computation. We explored the mathematical details of discretization, convolution, and numerical stability that define this influential work.
Reprinting: Please include the original address: https://kexue.fm/archives/10162