English (unofficial) translations of posts at kexue.fm
Source

Revisiting SSM (I): Linear Systems and HiPPO Matrices

Translated by Gemini Flash 3.0 Preview. Translations can be inaccurate, please refer to the original post for important stuff.

A few days ago, I read several articles introducing State Space Models (SSM) and realized that I had never seriously studied them. Consequently, I decided to learn about SSM-related content properly and, in the process, started this new series to record what I have learned.

The concept of SSM has been around for a long time, but here we specifically mean SSM in the context of deep learning, whose seminal work is generally considered to be the 2021 paper S4, which is fairly recent. The most recent and popular SSM variant is probably last year's Mamba. Of course, "SSM" can also refer broadly to all linear RNN models, so RWKV, RetNet, and the LRU we introduced previously in "Google's New Work Attempts to 'Revive' RNNs: Can RNNs Shine Again?" all fall into this category. Many SSM variants aim to compete with the Transformer. Although I do not believe they can replace it entirely, the elegant mathematical properties of SSM are worth studying in their own right.

Although we say SSM originated with S4, before S4, there was a very powerful foundational work titled "HiPPO: Recurrent Memory with Optimal Polynomial Projections" (referred to as HiPPO). Therefore, this article begins with HiPPO.

Basic Form

As a side note, the lead author of the representative SSM works mentioned above—HiPPO, S4, and Mamba—is Albert Gu. He has many other works related to SSM. It is no exaggeration to say that these works have built the foundation of the SSM edifice. Regardless of the future of SSM, this spirit of persistent research on the same topic deserves our sincere respect.

Returning to the main topic. Readers who already know a bit about SSM will be aware that SSM modeling uses a linear ODE (Ordinary Differential Equation) system: \begin{equation} \begin{aligned} x'(t) &= A x(t) + B u(t) \\ y(t) &= C x(t) + D u(t) \end{aligned} \label{eq:ode} \end{equation} where u(t)\in\mathbb{R}^{d_i}, x(t)\in\mathbb{R}^{d}, y(t)\in\mathbb{R}^{d_o}, A\in\mathbb{R}^{d\times d}, B\in\mathbb{R}^{d\times d_i}, C\in\mathbb{R}^{d_o\times d}, D\in\mathbb{R}^{d_o\times d_i}. Of course, the system can also be discretized into a linear RNN model, which we will expand upon in later articles. Whether discretized or not, the keyword is “linear,” which immediately raises a natural question: Why a linear system? Is a linear system sufficient?

We can answer this question from two perspectives: linear systems are both simple enough and complex enough. Simple enough because, theoretically, linearization is often the most basic approximation of a complex system, so linear systems are usually an unavoidable starting point. Complex enough because even such a simple system can fit remarkably complex functions. To see this, consider a simple example in \mathbb{R}^4: \begin{equation} x'(t) = \begin{pmatrix} 1 & 0 & 0 & 0 \\ 0 & -1 & 0 & 0 \\ 0 & 0 & 0 & 1 \\ 0 & 0 & -1 & 0 \end{pmatrix}x(t) \end{equation} Every solution of this example has components built from e^t, e^{-t}, \sin t, and \cos t; for instance, x(t) = (e^t, e^{-t}, \sin t, \cos t) is the solution with x(0) = (1, 1, 0, 1). What does this mean? It means that, as long as d is large enough, a linear system can fit sufficiently complex functions through combinations of exponential and trigonometric functions. The powerful Fourier series is just a combination of trigonometric functions, and adding exponential functions clearly makes the family even richer. Therefore, it is plausible that linear systems have sufficiently strong fitting capabilities.
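As a quick numerical sanity check of the example above, the sketch below integrates x'(t) = Mx(t) with a hand-written fourth-order Runge-Kutta loop and compares the result with (e^t, e^{-t}, \sin t, \cos t). It assumes NumPy; the names `M` and `rk4`, the horizon t = 2, and the step count are illustrative choices, not from the original post.

```python
import numpy as np

# The 4x4 matrix from the example: one growing mode, one decaying mode,
# and a 2x2 rotation block that produces sin/cos.
M = np.array([
    [1.0, 0.0, 0.0, 0.0],
    [0.0, -1.0, 0.0, 0.0],
    [0.0, 0.0, 0.0, 1.0],
    [0.0, 0.0, -1.0, 0.0],
])

def rk4(f, x0, t_end, steps):
    """Integrate x' = f(x) from t = 0 to t_end with classic RK4 (hypothetical helper)."""
    x = np.asarray(x0, dtype=float)
    h = t_end / steps
    for _ in range(steps):
        k1 = f(x)
        k2 = f(x + h / 2 * k1)
        k3 = f(x + h / 2 * k2)
        k4 = f(x + h * k3)
        x = x + h / 6 * (k1 + 2 * k2 + 2 * k3 + k4)
    return x

t_end = 2.0
# Initial state (1, 1, 0, 1) selects the solution (e^t, e^{-t}, sin t, cos t).
x_num = rk4(lambda x: M @ x, [1.0, 1.0, 0.0, 1.0], t_end, steps=2000)
x_exact = np.array([np.exp(t_end), np.exp(-t_end), np.sin(t_end), np.cos(t_end)])
print(np.max(np.abs(x_num - x_exact)))  # small RK4 discretization error
```

Any linear readout Cx(t) of this state is then a fixed combination of these four functions, which is the sense in which the system "fits" them.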

Of course, these explanations are, in a sense, “hindsight.” The results given by HiPPO are more fundamental: when we attempt to approximate a dynamically updated function using orthogonal bases, the result is a linear system as shown above. This means that HiPPO not only tells us that linear systems can approximate sufficiently complex functions but also tells us how to approximate them and even the degree of approximation.

Finite Compression

Next, we only consider the special case d_i=1; the case d_i > 1 is just a parallel generalization of it. In this case, u(t) is a scalar. Furthermore, as a starting point, let us assume t\in[0, 1]. The goal of HiPPO is to store the information in this segment of u(t) using a finite-dimensional vector.

This seems like an impossible requirement because, with t\in[0,1], u(t) can be viewed as a vector with infinitely many components, and compressing it into a finite-dimensional vector could cause severe distortion. However, if we make some assumptions about u(t) and tolerate some loss, this compression is possible, and most readers have already done it in some form. For example, if u(t) is (n+1) times differentiable at a certain point, its n-th order Taylor expansion there is often a good approximation of u(t). We can then store just the n+1 coefficients of the expansion as an approximate representation of u(t), compressing u(t) into an (n+1)-dimensional vector.

Of course, for data encountered in practice, a condition like “(n+1) times differentiable” is far too restrictive. We usually prefer expansions in an orthogonal function basis, which only require square integrability, such as the Fourier series, whose coefficients are computed as: \begin{equation} c_n = \int_0^1 u(t) e^{-2i\pi n t}dt \label{eq:fourier-coef-1} \end{equation} By choosing a sufficiently large integer N and keeping only the coefficients with |n|\leq N, we compress u(t) into a (2N + 1)-dimensional vector.
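A minimal sketch of this compression, assuming NumPy: the integral in [eq:fourier-coef-1] is approximated with a rectangle rule on a uniform grid, and the 2N + 1 kept coefficients are used to rebuild the signal. The helper names `fourier_compress` / `fourier_reconstruct` and the test signal \exp(\sin 2\pi t) are illustrative assumptions.

```python
import numpy as np

def fourier_compress(u, N, M=1024):
    """Approximate c_n = ∫_0^1 u(t) e^{-2iπnt} dt for |n| <= N (hypothetical helper).

    Uses the rectangle rule on M uniform points, which is very accurate
    for smooth 1-periodic u.
    """
    t = np.arange(M) / M
    n = np.arange(-N, N + 1)
    return n, np.exp(-2j * np.pi * np.outer(n, t)) @ u(t) / M

def fourier_reconstruct(n, c, t):
    """Evaluate the truncated series sum_n c_n e^{2iπnt} (real for real u)."""
    return np.real(np.exp(2j * np.pi * np.outer(t, n)) @ c)

u = lambda t: np.exp(np.sin(2 * np.pi * t))   # smooth, 1-periodic test signal
n, c = fourier_compress(u, N=10)              # 2N + 1 = 21 numbers
t = np.linspace(0, 1, 500)
max_err = np.max(np.abs(fourier_reconstruct(n, c, t) - u(t)))
```

The test signal is periodic on purpose: if the periodic extension of u jumps between t=1 and t=0, the same truncation converges much more slowly near the endpoints.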

Next, the difficulty level increases. We just said t\in[0,1], which is a static interval. In practice, u(t) represents a continuously collected signal, so new data is constantly entering. For example, if we have approximated the data in the interval [0,1], data for [1,2] comes in immediately. You need to update the approximation result to try to remember the entire [0,2] interval, then [0,3], [0,4], and so on. This is called “online function approximation.” The Fourier coefficient formula [eq:fourier-coef-1] above only applies to the interval [0,1], so it needs to be generalized.

To this end, let t\in[0,T], and let s\mapsto t_{\leq T}(s) be a mapping from [0,1] to [0,T]. Then, when u(t_{\leq T}(s)) is viewed as a function of s, its domain is [0,1], allowing us to reuse formula [eq:fourier-coef-1]: \begin{equation} c_n(T) = \int_0^1 u(t_{\leq T}(s)) e^{-2i\pi n s}ds \label{eq:fourier-coef-2} \end{equation} Here, we have added the notation (T) to the coefficients to indicate that they change as T changes.

Emergence of Linearity

There are infinitely many functions that can map [0,1] to [0,T], and the final result varies depending on t_{\leq T}(s). Some relatively intuitive and simple choices are as follows:

1. t_{\leq T}(s) = sT, which maps [0,1] uniformly to [0,T];

2. Note that t_{\leq T}(s) does not have to be surjective, so something like t_{\leq T}(s)=s + T - 1 is also allowed. This means only the information in the most recent window [T-1, T] is retained, and earlier parts are discarded. More generally, t_{\leq T}(s)=sw + T - w, where w is a constant, meaning information before T-w is discarded;

3. One can also choose a non-uniform mapping, such as t_{\leq T}(s) = T\sqrt{s}. This is also a surjective mapping from [0,1] to [0,T], but s=1/4 already maps to T/2, so the older half of the history, [0,T/2], occupies only s\in[0,1/4], while the recent half occupies the remaining three quarters. This means that while we attend to the entire history, we place more emphasis on information near time T.
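To see how the three mappings weight the history differently, the sketch below computes \int_0^1 u(t_{\leq T}(s))\,ds (which is exactly c_0(T) for the Fourier basis) with u(t) = t, so the result is the "mean time" each mapping looks at. For T\sqrt{s}, substituting t = T\sqrt{s} gives ds = 2t/T^2\,dt, a density that grows linearly toward T, so the mean time is 2T/3 rather than T/2. It assumes NumPy; the helper name `average_under_map` and the values T = 10, w = 2 are illustrative.

```python
import numpy as np
from numpy.polynomial import legendre as L

def average_under_map(u, t_map, M=64):
    """∫_0^1 u(t_map(s)) ds via Gauss-Legendre nodes moved from [-1,1] to [0,1].

    For the Fourier basis this is exactly c_0(T). (Hypothetical helper.)
    """
    x, wq = L.leggauss(M)
    s = (x + 1) / 2
    return np.sum(wq / 2 * u(t_map(s)))

T, w = 10.0, 2.0
maps = {
    "uniform": lambda s: s * T,            # option 1
    "window": lambda s: s * w + T - w,     # option 2, keeps [T-w, T]
    "sqrt": lambda s: T * np.sqrt(s),      # option 3, non-uniform
}
# With u(t) = t, the average is the "mean time" each mapping looks at.
mean_time = {name: average_under_map(lambda t: t, f) for name, f in maps.items()}
```

Here `mean_time` comes out as T/2 = 5 for the uniform map, T - w/2 = 9 for the window, and approximately 2T/3 \approx 6.67 for the square-root map (the last one is slightly inexact because \sqrt{s} is not a polynomial).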

Now, taking t_{\leq T}(s)=sw + T - w as an example, substituting it into formula [eq:fourier-coef-2] gives: \begin{equation} c_n(T) = \int_0^1 u(sw + T - w) e^{-2i\pi n s}ds \end{equation} Now we take the derivative of both sides with respect to T: \begin{equation} \begin{aligned} \frac{d}{dT}c_n(T) &= \int_0^1 u'(sw + T - w) e^{-2i\pi n s}ds \\ &= \left.\frac{1}{w} u(sw + T - w) e^{-2i\pi n s}\right|_{s=0}^{s=1} + \frac{2i\pi n}{w}\int_0^1 u(sw + T - w) e^{-2i\pi n s}ds \\ &= \frac{1}{w} u(T) - \frac{1}{w} u(T-w) + \frac{2i\pi n}{w} c_n(T) \\ \end{aligned} \label{eq:fourier-dc} \end{equation} where the second equality uses integration by parts. Since we only retain coefficients for |n|\leq N, according to the Fourier series formula, we can consider the following as a good approximation of u(sw + T - w): \begin{equation} u(sw + T - w) \approx \sum_{k=-N}^{k=N} c_k(T) e^{2i\pi k s} \end{equation} Then u(T - w) = u(sw + T - w)|_{s=0} \approx \sum_{k=-N}^{k=N} c_k(T). Substituting this into formula [eq:fourier-dc] yields: \begin{equation} \frac{d}{dT}c_n(T) \approx \frac{1}{w} u(T) - \frac{1}{w} \sum_{k=-N}^{k=N} c_k(T) + \frac{2i\pi n}{w} c_n(T) \end{equation} Replacing T with t and stacking all c_n(t) into a vector x(t) = (c_{-N}, c_{-(N-1)}, \dots, c_0, \dots, c_{N-1}, c_N), and treating \approx as =, we can write: \begin{equation} x'(t) = Ax(t) + Bu(t), \quad A_{n,k} = \begin{cases} (2i\pi n - 1)/w, & k=n \\ -1/w, & k \neq n \end{cases}, \quad B_n = 1/w \end{equation} This results in a linear ODE system as shown in formula [eq:ode]. That is, when we attempt to use a Fourier series to remember the state within the most recent window of a real-time function, a linear ODE system naturally emerges.
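A numerical sketch of this derivation, assuming NumPy: the coefficients c_n(T) are computed by Gauss-Legendre quadrature on [0,1], the exact identity [eq:fourier-dc] is checked against a central finite difference in T, and the matrices A, B of the approximate ODE are built. The only approximation is replacing u(T-w) with \sum_k c_k(T), so the gap between the exact derivative and the ODE right-hand side is the same in every row. The helper names, test signal, and parameter values are hypothetical.

```python
import numpy as np
from numpy.polynomial import legendre as L

def window_coeffs(u, T, w, N, M=100):
    """c_n(T) = ∫_0^1 u(sw + T - w) e^{-2iπns} ds for n = -N..N (Gauss-Legendre)."""
    x, wq = L.leggauss(M)
    s = (x + 1) / 2                     # nodes moved from [-1, 1] to [0, 1]
    n = np.arange(-N, N + 1)
    return np.exp(-2j * np.pi * np.outer(n, s)) @ (wq / 2 * u(s * w + T - w))

def fourier_window_matrices(N, w):
    """A, B of the approximate ODE x' = A x + B u (state ordered n = -N..N)."""
    n = np.arange(-N, N + 1)
    A = np.full((2 * N + 1, 2 * N + 1), -1.0 / w, dtype=complex)
    A[np.diag_indices(2 * N + 1)] = (2j * np.pi * n - 1) / w
    B = np.full(2 * N + 1, 1.0 / w)
    return A, B

u = lambda t: np.sin(t) + np.cos(3 * t)   # hypothetical test signal
T, w, N, h = 5.0, 2.0, 3, 1e-5
n = np.arange(-N, N + 1)
c = window_coeffs(u, T, w, N)

# Exact identity [eq:fourier-dc], checked against a central finite difference.
dc_fd = (window_coeffs(u, T + h, w, N) - window_coeffs(u, T - h, w, N)) / (2 * h)
dc_exact = (u(T) - u(T - w)) / w + 2j * np.pi * n / w * c

# Approximate ODE: replaces u(T - w) by the truncated series value sum_k c_k.
A, B = fourier_window_matrices(N, w)
dc_ode = A @ c + B * u(T)
gap = dc_exact - dc_ode   # equals (sum_k c_k - u(T - w)) / w in every row
```

How large `gap` is depends entirely on how well the truncated series represents u at the left end of the window.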

General Framework

Of course, we have only worked through one specific t_{\leq T}(s); a different choice might not yield such a simple result. Furthermore, the Fourier result lives in the complex domain; it can be rewritten in real form, but the expressions become complicated. Therefore, we generalize the process of the previous section into a general framework, in order to obtain more general and simpler results that stay purely in the real domain.

Let t\in[a,b]. Given a target function u(t) and a function basis \{g_n(t)\}_{n=0}^N, we want to approximate the former with a linear combination of the latter by minimizing the L_2 distance: \begin{equation} \mathop{\text{argmin}}_{c_0,\dots,c_N} \int_a^b \left[u(t) - \sum_{n=0}^N c_n g_n(t)\right]^2 dt \end{equation} Here we primarily consider the real domain, so a plain square suffices (no complex modulus is needed). A more general objective could include a weight function \rho(t), but we do not consider it here, as the main conclusions of HiPPO do not rely on it.

Expanding the objective function, we get: \begin{equation} \int_a^b u^2(t) dt - 2\sum_{n=0}^N c_n \int_a^b u(t) g_n(t)dt + \sum_{m=0}^N\sum_{n=0}^N c_m c_n \int_a^b g_m(t) g_n(t) dt \end{equation} We only consider orthonormal function bases, i.e., \int_a^b g_m(t) g_n(t) dt = \delta_{m,n}, where \delta_{m,n} is the Kronecker delta. In this case, the expression simplifies to: \begin{equation} \int_a^b u^2(t) dt - 2\sum_{n=0}^N c_n \int_a^b u(t) g_n(t)dt + \sum_{n=0}^N c_n^2 \end{equation} This is a separable quadratic function of the c_n; setting its derivative with respect to each c_n to zero gives the analytical minimizer: \begin{equation} c^*_n = \int_a^b u(t) g_n(t)dt \end{equation} This is the inner product of u(t) and g_n(t), the natural generalization of the inner product in finite-dimensional vector spaces to function spaces. From now on, c_n denotes the optimal c^*_n.
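The claim that the minimizer is simply the inner product can be checked numerically. The sketch below (assuming NumPy, [a,b] = [-1,1], the orthonormal Legendre basis introduced later in this article, and the hypothetical target u(t) = e^t) discretizes the L_2 objective with Gauss-Legendre quadrature, solves the resulting weighted least-squares problem directly, and compares it with the closed form c^*_n = \langle u, g_n \rangle. The helper `g` is a hypothetical name.

```python
import numpy as np
from numpy.polynomial import legendre as L

def g(n, s):
    """Orthonormal Legendre basis on [-1, 1]: g_n = sqrt((2n+1)/2) p_n."""
    return np.sqrt((2 * n + 1) / 2) * L.legval(s, np.eye(n + 1)[n])

N = 4
s, wq = L.leggauss(32)               # quadrature for ∫_{-1}^{1} (exact up to degree 63)
G = np.stack([g(n, s) for n in range(N + 1)], axis=1)   # G[j, n] = g_n(s_j)
u = np.exp(s)                        # hypothetical target u(t) = e^t on [a, b] = [-1, 1]

# (1) Closed form: c*_n = <u, g_n>.
c_inner = G.T @ (wq * u)

# (2) Brute force: minimize the quadrature-discretized L2 objective directly.
sqw = np.sqrt(wq)
c_lstsq, *_ = np.linalg.lstsq(sqw[:, None] * G, sqw * u, rcond=None)

gram = G.T @ (wq[:, None] * G)       # ∫ g_m g_n, should be the identity
```

Because the Gram matrix is the identity, the normal equations of the least-squares problem collapse to c = G^\top W u, which is exactly the inner-product formula.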

The subsequent processing is the same as in the previous section. We consider the approximation of u(t) for general t\in[0, T]. We find a mapping s\mapsto t_{\leq T}(s) from [a,b] to [0,T] and calculate the coefficients: \begin{equation} c_n(T) = \int_a^b u(t_{\leq T}(s)) g_n(s) ds \end{equation} Taking the derivative with respect to T and using integration by parts: \begin{equation} \scriptsize \begin{aligned} \frac{d}{dT}c_n(T) &= \int_a^b u'(t_{\leq T}(s)) \frac{\partial t_{\leq T}(s)}{\partial T} g_n(s) ds = \int_a^b \left(\frac{\partial t_{\leq T}(s)}{\partial T}\left/\frac{\partial t_{\leq T}(s)}{\partial s}\right.\right) g_n(s) d u(t_{\leq T}(s)) \\ &= \left.u(t_{\leq T}(s))\left(\frac{\partial t_{\leq T}(s)}{\partial T}\left/\frac{\partial t_{\leq T}(s)}{\partial s}\right.\right) g_n(s)\right|_{s=a}^{s=b} - \int_a^b u(t_{\leq T}(s)) \,d\left[\left(\frac{\partial t_{\leq T}(s)}{\partial T}\left/\frac{\partial t_{\leq T}(s)}{\partial s}\right.\right) g_n(s)\right] \end{aligned} \label{eq:hippo-base} \end{equation}

Enter Legendre

The subsequent calculations depend on the specific forms of g_n(t) and t_{\leq T}(s). HiPPO stands for High-order Polynomial Projection Operators, where the first ‘P’ stands for Polynomial, so the key to HiPPO is choosing polynomials as the basis. We now introduce the Legendre polynomials.

The Legendre polynomial p_n(t) is a polynomial of degree n in t, defined on [-1,1] and satisfying: \begin{equation} \int_{-1}^1 p_m(t) p_n(t) dt = \frac{2}{2n+1}\delta_{m,n} \end{equation} Thus, the p_n(t) are orthogonal but not orthonormal; the corresponding orthonormal basis is g_n(t) = \sqrt{\frac{2n+1}{2}} p_n(t).
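A small check of these orthogonality relations, assuming NumPy's `numpy.polynomial.legendre` module: with N+1 Gauss-Legendre nodes the quadrature is exact for polynomials of degree up to 2N+1, so the Gram matrices below are computed exactly up to rounding. The choice N = 6 is arbitrary.

```python
import numpy as np
from numpy.polynomial import legendre as L

N = 6
nodes, weights = L.leggauss(N + 1)  # exact for polynomials of degree <= 2N + 1
# Row n holds p_n evaluated at the quadrature nodes.
P = np.stack([L.legval(nodes, np.eye(N + 1)[n]) for n in range(N + 1)])
gram_p = (P * weights) @ P.T                  # ∫ p_m p_n dt
scale = np.sqrt((2 * np.arange(N + 1) + 1) / 2)
G = scale[:, None] * P                        # g_n = sqrt((2n+1)/2) p_n
gram_g = (G * weights) @ G.T                  # should be the identity
```

`gram_p` is diagonal with entries 2/(2n+1), and rescaling by \sqrt{(2n+1)/2} turns it into the identity.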

Applying Gram-Schmidt orthogonalization to \{1, t, t^2, \dots, t^n\} on [-1,1] yields exactly the Legendre polynomials, up to normalization (orthonormalizing gives the g_n(t)). Compared to the Fourier basis, Legendre polynomials have the advantage of being purely real, and their polynomial form simplifies the derivation for certain t_{\leq T}(s).
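The Gram-Schmidt statement can be reproduced directly. This sketch, assuming NumPy's `Polynomial` and `Legendre` classes, orthonormalizes the monomials 1, t, \dots, t^N under \langle f, g\rangle = \int_{-1}^1 f g\,dt (computed exactly through antiderivatives) and compares the result with \sqrt{(2n+1)/2}\,p_n. The helper names `inner` and `gram_schmidt_monomials` are hypothetical.

```python
import numpy as np
from numpy.polynomial import Legendre, Polynomial

def inner(f, g):
    """<f, g> = ∫_{-1}^{1} f(t) g(t) dt, exact via the antiderivative."""
    h = (f * g).integ()
    return float(h(1.0) - h(-1.0))

def gram_schmidt_monomials(N):
    """Orthonormalize 1, t, ..., t^N on [-1, 1] (classical Gram-Schmidt)."""
    basis = []
    for n in range(N + 1):
        v = Polynomial.basis(n)                 # the monomial t^n
        for q in basis:
            v = v - inner(v, q) * q             # remove components along earlier vectors
        basis.append(v / np.sqrt(inner(v, v)))  # normalize; leading coefficient stays positive
    return basis

N = 6
gs = gram_schmidt_monomials(N)
t = np.linspace(-1, 1, 201)
max_diff = max(
    np.max(np.abs(gs[n](t) - np.sqrt((2 * n + 1) / 2) * Legendre.basis(n)(t)))
    for n in range(N + 1)
)
```

Since each step keeps a positive leading coefficient, as p_n does, the two families agree exactly rather than only up to sign.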

We use two recurrence formulas to derive the identities we need: \begin{align} p_{n+1}'(t) - p_{n-1}'(t) &= (2n+1)p_n(t) \label{eq:leg-r1} \\[5pt] p_{n+1}'(t) &= (n + 1)p_n(t) + t p_n'(t) \label{eq:leg-r2} \end{align} Iterating the first formula [eq:leg-r1] gives: \begin{equation} \begin{aligned} p_{n+1}'(t) &= (2n+1)p_n(t) + (2n-3)p_{n-2}(t) + (2n-7)p_{n-4}(t) + \dots \\ &= \sum_{k=0}^n (2k+1) \chi_{n-k} p_k(t) \end{aligned} \label{eq:leg-dot} \end{equation} where \chi_k=1 if k is even and \chi_k=0 otherwise. Substituting this into the second formula [eq:leg-r2] yields: \begin{equation} t p_n'(t) = n p_n(t) + (2n-3)p_{n-2}(t) + (2n-7)p_{n-4}(t) + \dots \end{equation} Adding p_n'(t) = (2n-1)p_{n-1}(t) + (2n-5)p_{n-3}(t) + \dots, which is [eq:leg-dot] with n replaced by n-1, gives: \begin{equation} \begin{aligned} (t+1) p_n'(t) &= n p_n(t) + (2n-1)p_{n-1}(t) + (2n-3)p_{n-2}(t) + \dots \\ &= -(n+1) p_n(t) + \sum_{k=0}^n (2k + 1) p_k(t) \end{aligned} \label{eq:leg-dot-t1} \end{equation} These are the identities we will use. Additionally, Legendre polynomials satisfy p_n(1)=1 and p_n(-1)=(-1)^n.
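All of these identities can be spot-checked on a grid. The sketch below assumes NumPy's `Legendre` class; the function name `identity_errors` and the grid size are illustrative. It returns the largest absolute violation of each identity for n = 0, \dots, N.

```python
import numpy as np
from numpy.polynomial import Legendre

def identity_errors(N, t=None):
    """Max violation of [eq:leg-r1], [eq:leg-r2], [eq:leg-dot], [eq:leg-dot-t1], endpoints."""
    if t is None:
        t = np.linspace(-1.0, 1.0, 101)
    p = [Legendre.basis(n) for n in range(N + 2)]
    dp = [q.deriv() for q in p]
    chi = lambda k: 1.0 if k % 2 == 0 else 0.0
    err = {"r1": 0.0, "r2": 0.0, "leg-dot": 0.0, "leg-dot-t1": 0.0, "endpoints": 0.0}
    for n in range(N + 1):
        if n >= 1:
            lhs = dp[n + 1](t) - dp[n - 1](t)
            err["r1"] = max(err["r1"], np.max(np.abs(lhs - (2 * n + 1) * p[n](t))))
        rhs = (n + 1) * p[n](t) + t * dp[n](t)
        err["r2"] = max(err["r2"], np.max(np.abs(dp[n + 1](t) - rhs)))
        rhs = sum((2 * k + 1) * chi(n - k) * p[k](t) for k in range(n + 1))
        err["leg-dot"] = max(err["leg-dot"], np.max(np.abs(dp[n + 1](t) - rhs)))
        rhs = -(n + 1) * p[n](t) + sum((2 * k + 1) * p[k](t) for k in range(n + 1))
        err["leg-dot-t1"] = max(err["leg-dot-t1"], np.max(np.abs((t + 1) * dp[n](t) - rhs)))
        err["endpoints"] = max(err["endpoints"],
                               abs(p[n](1.0) - 1.0), abs(p[n](-1.0) - (-1.0) ** n))
    return err

errors = identity_errors(10)
```

Every entry of `errors` should be at the level of floating-point rounding.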

Neighborhood Window

We now substitute specific t_{\leq T}(s) for calculation. As a first example, consider retaining only information in the most recent window: t_{\leq T}(s) = (s + 1)w / 2 + T - w, which maps [-1,1] to [T-w, T]. This is called LegT (Translated Legendre).

Substituting into formula [eq:hippo-base], where \partial t_{\leq T}(s)/\partial T = 1 and \partial t_{\leq T}(s)/\partial s = w/2, and using g_n(1) = \sqrt{\frac{2n+1}{2}} and g_n(-1) = (-1)^n\sqrt{\frac{2n+1}{2}}, gives: \begin{equation} \small \frac{d}{dT}c_n(T) = \frac{\sqrt{2(2n+1)}}{w}\left[u(T) - (-1)^n u(T-w)\right] - \frac{2}{w}\int_{-1}^1 u((s + 1)w / 2 + T - w) g_n'(s) ds \end{equation} Approximating u with the truncated basis: \begin{equation} u((s + 1)w / 2 + T - w) \approx \sum_{k=0}^N c_k(T)g_k(s) \end{equation} Thus u(T-w) \approx \sum_{k=0}^N (-1)^k c_k(T) \sqrt{\frac{2k+1}{2}}. Using formula [eq:leg-dot] with n replaced by n-1: \begin{equation} \begin{aligned} &\int_{-1}^1 u((s + 1)w / 2 + T - w) g_n'(s) ds \\ &= \sqrt{2n+1}\sum_{k=0}^{n-1} \sqrt{2k+1} \chi_{n-1-k} c_k(T) \end{aligned} \end{equation} Combining these results: \begin{equation} \begin{aligned} \frac{d}{dT}c_n(T) \approx & \frac{\sqrt{2(2n+1)}}{w}u(T) - \frac{\sqrt{2n+1}}{w} \sum_{k=n}^N (-1)^{n-k} c_k(T) \sqrt{2k+1} \\ & - \frac{\sqrt{2n+1}}{w}\sum_{k=0}^{n-1} \sqrt{2k+1} \underbrace{\left(2\chi_{n-1-k} + (-1)^{n-k}\right)}_{\equiv 1}c_k(T) \end{aligned} \label{eq:leg-t} \end{equation} where the bracket always equals 1 (if n-k is odd it is 2 - 1; if n-k is even it is 0 + 1). Stacking c_n(t) into x(t) = (c_0, \dots, c_N): \begin{equation} \begin{aligned} x'(t) &= Ax(t) + Bu(t) \\ A_{n,k} &= -\frac{1}{w}\begin{cases} \sqrt{(2n+1)(2k+1)}, & k < n \\ (-1)^{n-k}\sqrt{(2n+1)(2k+1)}, & k \geq n \end{cases} \\ B_n &= \frac{1}{w}\sqrt{2(2n+1)} \end{aligned} \label{eq:leg-t-hippo-1} \end{equation} Other normalizations of the state are possible: rescaling it to \Lambda x(t) with \Lambda = \text{diag}(\lambda_0, \dots, \lambda_N) turns the system into (\Lambda x)' = \Lambda A \Lambda^{-1} (\Lambda x) + \Lambda B u, and suitable \lambda_n align the result with the original paper or with Legendre Memory Units. For instance, with \lambda_n = \sqrt{\frac{2n+1}{2}} (equivalent to expanding in the unnormalized p_n instead of g_n): \begin{equation} A_{n,k} = -\frac{1}{w}\begin{cases} 2n+1, & k < n \\ (-1)^{n-k}(2n+1), & k \geq n \end{cases}, \quad B_n = \frac{1}{w}(2n+1) \end{equation}
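A sketch that builds the LegT matrices of [eq:leg-t-hippo-1] and checks them numerically, assuming NumPy. The only approximation in the derivation is the truncated expansion used for u(T-w), which is exact when u is a polynomial of degree at most N on the window. So, for a cubic test signal and N = 5, the finite-difference derivative of c(T) should match Ax + Bu exactly up to rounding. The helper names, the test signal, and the parameter values are hypothetical.

```python
import numpy as np
from numpy.polynomial import legendre as L

def g(n, s):
    """Orthonormal Legendre basis g_n = sqrt((2n+1)/2) p_n on [-1, 1]."""
    return np.sqrt((2 * n + 1) / 2) * L.legval(s, np.eye(n + 1)[n])

def legt_coeffs(u, T, w, N, M=64):
    """c_n(T) = ∫_{-1}^{1} u((s+1)w/2 + T - w) g_n(s) ds via Gauss-Legendre."""
    s, wq = L.leggauss(M)
    vals = u((s + 1) * w / 2 + T - w)
    return np.array([np.sum(wq * vals * g(n, s)) for n in range(N + 1)])

def legt_matrices(N, w):
    """A and B from [eq:leg-t-hippo-1]."""
    r = np.sqrt(2 * np.arange(N + 1) + 1)
    n_idx, k_idx = np.indices((N + 1, N + 1))
    parity = np.where((n_idx - k_idx) % 2 == 0, 1.0, -1.0)   # (-1)^{n-k}
    sign = np.where(k_idx < n_idx, 1.0, parity)
    A = -sign * np.outer(r, r) / w
    B = np.sqrt(2) * r / w
    return A, B

u = lambda t: 1 + t - 0.5 * t**3     # degree 3 <= N, so the truncation is exact
T, w, N, h = 2.0, 1.5, 5, 1e-4
A, B = legt_matrices(N, w)
c = legt_coeffs(u, T, w, N)
dc_fd = (legt_coeffs(u, T + h, w, N) - legt_coeffs(u, T - h, w, N)) / (2 * h)
dc_ode = A @ c + B * u(T)
```

For a non-polynomial signal the two sides would differ by the truncation error of u(T-w), which is exactly where LegT is approximate.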

Entire Interval

Now consider t_{\leq T}(s) = (s + 1)T / 2, mapping [-1,1] uniformly to [0,T]. This is LegS (Scaled Legendre).

Substituting into [eq:hippo-base]: \begin{equation} \frac{d}{dT}c_n(T) = \frac{\sqrt{2(2n+1)}}{T}u(T) - \frac{1}{T}\int_{-1}^1 u((s + 1)T / 2) \left[g_n(s) + (s+1) g_n'(s)\right] ds \end{equation} Using formula [eq:leg-dot-t1]: \begin{equation} \begin{aligned} &\int_{-1}^1 u((s + 1)T / 2) \left[g_n(s) + (s+1) g_n'(s)\right] ds \\ &= -n c_n(T) + \sum_{k=0}^n \sqrt{(2n+1)(2k + 1)} c_k(T) \end{aligned} \end{equation} Thus: \begin{equation} \frac{d}{dT}c_n(T) = \frac{\sqrt{2(2n+1)}}{T}u(T) - \frac{1}{T}\left(-n c_n(T) + \sum_{k=0}^n \sqrt{(2n+1)(2k + 1)} c_k(T)\right) \end{equation} In vector form: \begin{equation} \begin{aligned} x'(t) &= \frac{A}{t}x(t) + \frac{B}{t}u(t) \\ A_{n,k} &= -\begin{cases} \sqrt{(2n+1)(2k+1)}, & k < n \\ n+1, & k = n \\ 0, & k > n \end{cases} \\ B_n &= \sqrt{2(2n+1)} \end{aligned} \label{eq:leg-s-hippo} \end{equation}
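The LegS system can be checked the same way, assuming NumPy. Here no truncation approximation was used (the integral involving g_n(s) + (s+1)g_n'(s) reduces exactly to a combination of c_0, \dots, c_n), so the check should hold even for a non-polynomial signal. The helper names, the test signal \sin t + 0.3t, and the parameters are hypothetical.

```python
import numpy as np
from numpy.polynomial import legendre as L

def g(n, s):
    """Orthonormal Legendre basis g_n = sqrt((2n+1)/2) p_n on [-1, 1]."""
    return np.sqrt((2 * n + 1) / 2) * L.legval(s, np.eye(n + 1)[n])

def legs_coeffs(u, T, N, M=64):
    """c_n(T) = ∫_{-1}^{1} u((s+1)T/2) g_n(s) ds via Gauss-Legendre."""
    s, wq = L.leggauss(M)
    vals = u((s + 1) * T / 2)
    return np.array([np.sum(wq * vals * g(n, s)) for n in range(N + 1)])

def legs_matrices(N):
    """A and B from [eq:leg-s-hippo] (the 1/t factor is applied separately)."""
    n = np.arange(N + 1)
    r = np.sqrt(2 * n + 1)
    A = -np.tril(np.outer(r, r), k=-1) - np.diag(n + 1.0)   # strictly lower part, then -(n+1) diagonal
    B = np.sqrt(2) * r
    return A, B

u = lambda t: np.sin(t) + 0.3 * t    # non-polynomial on purpose
T, N, h = 3.0, 6, 1e-5
A, B = legs_matrices(N)
c = legs_coeffs(u, T, N)
dc_fd = (legs_coeffs(u, T + h, N) - legs_coeffs(u, T - h, N)) / (2 * h)
dc_ode = (A @ c + B * u(T)) / T
```

The lower-triangular structure of A means each c_n only depends on c_0, \dots, c_n, matching the sum over k \leq n in the derivation.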

Extended Thinking

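Reflecting on the derivation of LegS, the key step was decomposing (s+1)g_n'(s) into a linear combination of g_0, \dots, g_n. For a polynomial basis this decomposition is exact, because (s+1)g_n'(s) is itself a polynomial of degree n; consequently, the LegS equations above hold exactly, without any truncation approximation. For the Fourier basis the analogous term involves s\,e^{-2i\pi n s}, which is not a finite combination of basis functions, so the decomposition cannot be done exactly. This is why orthogonal polynomials are chosen to simplify the derivation.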

HiPPO is a bottom-up framework. It doesn’t assume linearity from the start but derives it from the perspective of orthogonal basis approximation. This gives us confidence that linear ODE systems are sufficient for fitting complex functions.

However, there is a trade-off between resolution and memory length. Since the dimension of x(t) is fixed, fitting a function over a longer interval inevitably lowers the resolution. To maintain accuracy over longer inputs, one must increase the state dimension d (the hidden size); this is a characteristic of all linear systems.
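The trade-off is easy to observe with LegS, assuming NumPy. The sketch below projects u(t) = \sin t on [0,T] onto a fixed number of orthonormal Legendre components (N = 8, i.e. 9 numbers) and measures the RMS reconstruction error over the interval. With T = 10 the signal has under two oscillations and the fit is good; with T = 100 it has about sixteen, and the same 9 numbers can no longer resolve them. The helper names, N, and the two values of T are illustrative.

```python
import numpy as np
from numpy.polynomial import legendre as L

def g(n, s):
    """Orthonormal Legendre basis g_n = sqrt((2n+1)/2) p_n on [-1, 1]."""
    return np.sqrt((2 * n + 1) / 2) * L.legval(s, np.eye(n + 1)[n])

def legs_rms_error(u, T, N, M=200, grid=2001):
    """RMS error of reconstructing u on [0, T] from its first N+1 LegS coefficients."""
    s, wq = L.leggauss(M)
    c = [np.sum(wq * u((s + 1) * T / 2) * g(n, s)) for n in range(N + 1)]
    s_eval = np.linspace(-1, 1, grid)
    u_hat = sum(c[n] * g(n, s_eval) for n in range(N + 1))
    return np.sqrt(np.mean((u((s_eval + 1) * T / 2) - u_hat) ** 2))

err_short = legs_rms_error(np.sin, 10.0, N=8)    # about 1.6 periods in [0, 10]
err_long = legs_rms_error(np.sin, 100.0, N=8)    # about 16 periods in [0, 100]
```

Keeping `err_long` small would require raising N, i.e. enlarging the state, which is the cost described above.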

Summary

This article has reproduced the main derivations of the HiPPO paper. Under appropriate memory assumptions, HiPPO derives linear ODE systems from the bottom up and, for Legendre polynomials, gives analytical solutions (the HiPPO matrices). These results are a foundational pillar of many subsequent State Space Models.

Reprinting should include the original address: https://kexue.fm/archives/10114

For more details on reprinting, please refer to: "Scientific Space FAQ"