English (unofficial) translations of posts at kexue.fm

Revisiting SSM (II): Some Legacy Issues of HiPPO

Translated by Gemini Flash 3.0 Preview. Translations can be inaccurate, please refer to the original post for important stuff.

Continuing from the previous post: in "Revisiting SSM (I): Linear Systems and the HiPPO Matrix", we derived the HiPPO matrix in detail within the HiPPO approximation framework. The principle is to dynamically approximate a function that is updated in real time using an orthogonal function basis; the dynamics of the projection coefficients then turn out to be a linear system. If orthogonal polynomials are used as the basis, the core matrix of that linear system can be solved analytically, and this matrix is called the HiPPO matrix.

Of course, the previous article focused on the derivation of the HiPPO matrix and did not further analyze its properties. Additionally, issues such as "how to discretize for application to actual data" and "whether other bases besides polynomial bases can be solved analytically" were not discussed in detail. Next, we will supplement the discussion on these related issues.

Discrete Formats

Assuming the reader has read and understood the content of the previous article, we will not provide excessive background here. In the previous article, we derived two types of linear ODE systems: \begin{align} &\text{HiPPO-LegT:}\quad x'(t) = Ax(t) + Bu(t) \label{eq:legt-ode}\\[5pt] &\text{HiPPO-LegS:}\quad x'(t) = \frac{A}{t}x(t) + \frac{B}{t}u(t) \label{eq:legs-ode} \end{align} where A and B are constant matrices independent of time t, and the HiPPO matrix primarily refers to matrix A. In this section, we discuss the discretization of these two ODEs.

Input Transformation

In practical scenarios, the input data points are a discrete sequence u_0, u_1, u_2, \dots, u_k, \dots, such as streaming audio signals or text vectors. We hope to use the above ODE systems to memorize these discrete points in real-time. To this end, we first define: \begin{equation} u(t) = u_k, \quad \text{if } t \in [k\epsilon, (k + 1)\epsilon) \end{equation} where \epsilon is the discretization step size. This definition means that within the interval [k\epsilon, (k + 1)\epsilon), u(t) is a constant function equal to u_k. Obviously, u(t) defined this way preserves the information of the original u_k sequence without loss; therefore, memorizing u(t) is equivalent to memorizing the u_k sequence.

Transforming from u_k to u(t) allows the input signal to return to a function on a continuous interval, facilitating subsequent operations like integration. Furthermore, remaining constant within the discretization interval simplifies the discretized format.
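As a small illustration of this input transformation, here is a minimal sketch in Python (NumPy). The helper name hold and the sample values are our own choices for illustration and do not come from the original derivation.

```python
import numpy as np

def hold(u_seq, eps):
    """Return a function u(t) equal to u_seq[k] on [k*eps, (k+1)*eps)."""
    u_seq = np.asarray(u_seq, dtype=float)

    def u(t):
        k = int(np.floor(t / eps))  # index of the interval that contains t
        return float(u_seq[k])

    return u

# Hypothetical three-point sequence with step size 0.5.
u = hold([3.0, -1.0, 4.0], eps=0.5)
samples = [u(0.0), u(0.49), u(0.5), u(1.2)]
```

Because u(t) is constant on each interval, its integral over [k\epsilon, (k+1)\epsilon) is simply \epsilon u_k, which is what the schemes below rely on.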

LegT Version

We first take the LegT-type ODE [eq:legt-ode] as an example and integrate it on both sides: \begin{equation} x(t+\epsilon) - x(t) = A\int_t^{t+\epsilon} x(s)ds + B\int_t^{t+\epsilon}u(s)ds \end{equation} where t=k\epsilon. According to the definition of u(t), it is constant u_k in the interval [t, t + \epsilon), so the integral of u(s) can be calculated directly: \begin{equation} x(t+\epsilon) - x(t) = A\int_t^{t+\epsilon} x(s)ds + \epsilon B u_k \end{equation} The subsequent result depends on how we approximate the integral of x(s). If we assume x(s) is approximately equal to x(t) in the interval [t, t + \epsilon), we get the Forward Euler format: \begin{equation} x(t+\epsilon) - x(t) = \epsilon A x(t) + \epsilon B u_k \quad\Rightarrow\quad x(t+\epsilon) = (I + \epsilon A)x(t) + \epsilon B u_k \end{equation} If we assume x(s) is approximately equal to x(t+\epsilon) in the interval [t, t + \epsilon), we get the Backward Euler format: \begin{equation} x(t+\epsilon) - x(t) = \epsilon A x(t+\epsilon) + \epsilon B u_k \quad\Rightarrow\quad x(t+\epsilon) = (I - \epsilon A)^{-1}(x(t) + \epsilon B u_k) \end{equation} Both Forward and Backward Euler have the same theoretical accuracy, but Backward Euler usually has better numerical stability. To be more accurate, if we assume x(s) is approximately equal to \frac{1}{2}[x(t) + x(t+\epsilon)] in the interval [t, t + \epsilon), we obtain the bilinear form: \begin{equation} \begin{gathered} x(t+\epsilon) - x(t) = \frac{1}{2}\epsilon A [x(t) + x(t+\epsilon)] + \epsilon B u_k \\ \Downarrow \\ x(t+\epsilon) = (I - \epsilon A/2)^{-1}[(I + \epsilon A/2) x(t) + \epsilon B u_k] \end{gathered} \end{equation} This is also equivalent to taking a half-step with Forward Euler and then a half-step with Backward Euler. More generally, we can assume x(s) \approx \alpha x(t) + (1 - \alpha) x(t+\epsilon) where \alpha \in [0,1], but we won’t expand on that. 
In fact, we can solve it exactly without approximation because, combined with equation [eq:legt-ode] and the fact that u(s) is constant u_k in [t, t+\epsilon), we can use the "variation of parameters" method to solve it precisely: \begin{equation} x(t+\epsilon) = e^{\epsilon A} x(t) + A^{-1} (e^{\epsilon A} - I) B u_k \label{eq:legt-ode-sol} \end{equation} The matrix exponential here is defined by its series; refer to "Appreciation of the Identity det(exp(A)) = exp(Tr(A))".
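To make the four update rules concrete, the following sketch implements one step of each and compares the approximate ones with the exact step. The 2 \times 2 matrices A and B, the state x0, and the helper expm_series are stand-ins chosen purely for illustration (they are not the LegT matrices); the matrix exponential is evaluated from its power series, which is adequate only because \epsilon A is small here.

```python
import numpy as np

def expm_series(M, terms=30):
    """Matrix exponential from its power series (fine for small ||M||)."""
    out = np.eye(M.shape[0])
    term = np.eye(M.shape[0])
    for j in range(1, terms):
        term = term @ M / j
        out = out + term
    return out

def step_forward(A, B, x, u, eps):
    return x + eps * (A @ x + B * u)

def step_backward(A, B, x, u, eps):
    I = np.eye(len(x))
    return np.linalg.solve(I - eps * A, x + eps * B * u)

def step_bilinear(A, B, x, u, eps):
    I = np.eye(len(x))
    return np.linalg.solve(I - eps / 2 * A, (I + eps / 2 * A) @ x + eps * B * u)

def step_exact(A, B, x, u, eps):
    E = expm_series(eps * A)
    # A^{-1} (e^{eps A} - I) B u, using a linear solve instead of an inverse.
    return E @ x + np.linalg.solve(A, (E - np.eye(len(x))) @ B) * u

# Stand-in system (an assumption for illustration, not the LegT matrices).
A = np.array([[-1.0, 0.0], [-np.sqrt(3.0), -2.0]])
B = np.array([1.0, np.sqrt(3.0)])
x0 = np.array([0.5, -0.2])
u_k, eps = 1.0, 0.1

exact = step_exact(A, B, x0, u_k, eps)
errors = {
    "forward": np.linalg.norm(step_forward(A, B, x0, u_k, eps) - exact),
    "backward": np.linalg.norm(step_backward(A, B, x0, u_k, eps) - exact),
    "bilinear": np.linalg.norm(step_bilinear(A, B, x0, u_k, eps) - exact),
}
```

For this stand-in system the bilinear error comes out smaller than either Euler error, in line with its higher order of accuracy.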

LegS Version

Now for the LegS-type ODE. The logic is basically consistent with the LegT type, and the results are similar. First, integrating equation [eq:legs-ode] on both sides gives: \begin{equation} x(t+\epsilon) - x(t) = A\int_t^{t+\epsilon} \frac{x(s)}{s}ds + B\int_t^{t+\epsilon}\frac{u(s)}{s}ds \end{equation} According to the definition of u(t), u(s) in the second integral is constant u_k in [t, t+\epsilon), so it reduces to u_k times the integral of 1/s, which yields \ln\frac{t+\epsilon}{t}. Of course, replacing it with the first-order approximation \frac{\epsilon}{t} is also fine, as the transformation from u_k to u(t) has significant freedom, and this error is negligible. For the first integral, we use the higher-precision trapezoidal approximation, which averages the integrand at the two endpoints: \begin{equation} \begin{gathered} x(t+\epsilon) - x(t) = \frac{1}{2}\epsilon A\left(\frac{x(t)}{t}+\frac{x(t+\epsilon)}{t+\epsilon}\right) + \frac{\epsilon}{t} B u_k \\[5pt] \Downarrow \\[5pt] x(t+\epsilon) = \left(I - \frac{\epsilon A}{2(t+\epsilon)}\right)^{-1}\left[\left(I + \frac{\epsilon A}{2t}\right)x(t) + \frac{\epsilon}{t} B u_k\right] \end{gathered} \label{eq:legs-ode-bilinear} \end{equation} In fact, equation [eq:legs-ode] can also be solved exactly by noting it is equivalent to: \begin{equation} Ax(t) + Bu(t) = t x'(t) = \frac{d}{d\ln t} x(t) \end{equation} This means by performing a variable substitution \tau = \ln t, the LegS-type ODE can be converted into a LegT-type ODE: \begin{equation} \frac{d}{d\tau} x(e^{\tau}) = Ax(e^{\tau}) + Bu(e^{\tau}) \end{equation} Using equation [eq:legt-ode-sol] (due to the variable substitution, the time interval changes from \epsilon to \ln(t+\epsilon) - \ln t): \begin{equation} x(t+\epsilon) = e^{(\ln(t+\epsilon) - \ln t) A} x(t) + A^{-1} \big(e^{(\ln(t+\epsilon) - \ln t) A} - I\big) B u_k \label{eq:legs-ode-sol} \end{equation} However, although the above is an exact solution, it is not as useful as the exact solution in equation [eq:legt-ode-sol]. 
In [eq:legt-ode-sol], the matrix exponential part is e^{\epsilon A}, which is independent of time t and can be computed once. But in the above equation, t is inside the matrix exponential, meaning the matrix exponential must be repeatedly calculated during iteration, which is not computationally friendly. Therefore, for LegS-type ODEs, we generally only use equation [eq:legs-ode-bilinear] for discretization.
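A minimal sketch of one step of [eq:legs-ode-bilinear] might look as follows. The matrix A follows the HiPPO-LegS definition recalled in the Polynomial Decay section; B_n = \sqrt{2n+1} is our assumption based on the standard HiPPO-LegS form, and the function names, the test signal, and the choice to start at t = \epsilon (to avoid the 1/t factor at t = 0) are illustrative.

```python
import numpy as np

def legs_matrices(d):
    """HiPPO-LegS A (lower triangular); B_n = sqrt(2n+1) is assumed here."""
    n = np.arange(d)
    r = np.sqrt(2 * n + 1)
    A = -np.tril(np.outer(r, r), k=-1) - np.diag(n + 1.0)
    return A, r.copy()

def legs_bilinear_step(A, B, x, u_k, t, eps):
    """One update x(t) -> x(t + eps); note the matrices depend on t."""
    I = np.eye(len(x))
    rhs = (I + eps / (2 * t) * A) @ x + (eps / t) * B * u_k
    return np.linalg.solve(I - eps / (2 * (t + eps)) * A, rhs)

d, eps = 4, 0.1
A, B = legs_matrices(d)
u_seq = np.sin(0.3 * np.arange(1, 50))  # hypothetical input signal
x = np.zeros(d)
for k, u_k in enumerate(u_seq, start=1):  # start at t = eps, not t = 0
    x = legs_bilinear_step(A, B, x, u_k, t=k * eps, eps=eps)
```

The dense solve is used here only for readability; it costs more than necessary, and the Computational Efficiency section shows how the structure of A allows a cheaper update.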

Excellent Properties

Next, LegS is our primary focus. The reason for focusing on LegS is not hard to guess: based on the derivation assumptions, it is currently the only solved ODE system capable of memorizing the entire history, which is crucial for many scenarios like multi-turn dialogue. Additionally, it has other good and practical properties.

Scale Equivariance

For example, the discretization format [eq:legs-ode-bilinear] of LegS is step-size independent. By substituting t=k\epsilon into it and denoting x(k\epsilon)=x_k, we find: \begin{equation} x_{k+1} = \left(I - \frac{A}{2(k + 1)}\right)^{-1}\left[\left(I + \frac{A}{2k}\right)x_k + \frac{1}{k} B u_k\right] \end{equation} The step size \epsilon is automatically eliminated, naturally reducing a hyperparameter that needs tuning, which is good news for practitioners. Note that step-size independence is an inherent property of LegS-type ODEs and is not directly related to the specific discretization method. For instance, the exact solution [eq:legs-ode-sol] is also step-size independent: \begin{equation} x_{k+1} = e^{(\ln(k+1) - \ln k) A} x_k + A^{-1} \big(e^{(\ln(k+1) - \ln k) A} - I\big) B u_k \label{eq:legs-ode-sol-2} \end{equation} The underlying reason is that LegS-type ODEs satisfy "Timescale equivariance". If we substitute t=\lambda\tau into the LegS-type ODE, we get: \begin{equation} Ax(\lambda\tau) + Bu(\lambda\tau) = (\lambda\tau) \times \frac{d}{d(\lambda\tau)} x(\lambda\tau) = \tau \frac{d}{d\tau}x(\lambda\tau) \end{equation} This means when we replace u(t) with u(\lambda t), the form of the LegS ODE does not change, and the corresponding solution is x(t) replaced by x(\lambda t). The direct consequence of this property is that when we choose a larger step size, the recursive format does not need to change because the step size of the result x_k will also automatically scale. This is the fundamental reason why LegS-type ODE discretization is step-size independent.
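The step-size independence can also be checked numerically. The sketch below runs the bilinear recursion on the same input sequence with two very different values of \epsilon and compares the final states. As before, B_n = \sqrt{2n+1}, the function names, and the test signal are assumptions made for illustration.

```python
import numpy as np

def legs_matrices(d):
    """HiPPO-LegS A (lower triangular); B_n = sqrt(2n+1) is assumed here."""
    n = np.arange(d)
    r = np.sqrt(2 * n + 1)
    A = -np.tril(np.outer(r, r), k=-1) - np.diag(n + 1.0)
    return A, r.copy()

def run_legs(u_seq, eps, d=4):
    """Bilinear LegS recursion with t = k * eps, starting from k = 1."""
    A, B = legs_matrices(d)
    I = np.eye(d)
    x = np.zeros(d)
    for k, u_k in enumerate(u_seq, start=1):
        t = k * eps
        rhs = (I + eps / (2 * t) * A) @ x + (eps / t) * B * u_k
        x = np.linalg.solve(I - eps / (2 * (t + eps)) * A, rhs)
    return x

u_seq = np.cos(0.2 * np.arange(100))  # hypothetical input signal
x_coarse = run_legs(u_seq, eps=1.0)
x_fine = run_legs(u_seq, eps=1e-3)
max_gap = np.max(np.abs(x_coarse - x_fine))  # differs only by rounding
```

Since \epsilon cancels algebraically in every coefficient, the two runs should agree up to floating-point rounding.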

Polynomial Decay

Another excellent property of LegS-type ODEs is that the memory of historical signals exhibits polynomial decay, which is slower than the exponential decay of conventional RNNs. Theoretically, this allows for memorizing longer histories and makes it less prone to vanishing gradients. To understand this, we can start from the exact solution [eq:legs-ode-sol-2]. From [eq:legs-ode-sol-2], we see that for each recursion step, the decay effect of historical information can be described by the matrix exponential e^{(\ln(k+1) - \ln k) A}. Then, from step m to step n, the total decay effect is: \begin{equation} \prod_{k=m}^{n-1} e^{(\ln(k+1) - \ln k) A} = e^{(\ln n - \ln m) A} \end{equation} Recall the form of A in HiPPO-LegS: \begin{equation} A_{n,k} = -\left\{\begin{array}{ll}\sqrt{(2n+1)(2k+1)}, & k < n \\ n+1, & k = n \\ 0, & k > n\end{array}\right. \end{equation} From the definition, A is a lower triangular matrix with diagonal elements -1, -2, -3, \dots. We know that the diagonal elements of a triangular matrix are exactly its eigenvalues. Thus, a d \times d matrix A has d distinct eigenvalues -1, -2, \dots, -d. This implies A is diagonalizable, i.e., there exists an invertible matrix P such that A = P^{-1}\Lambda P, where \Lambda = \text{diag}(-1, -2, \dots, -d). Thus, we have: \begin{equation} \begin{aligned} e^{(\ln n - \ln m) A} &= e^{(\ln n - \ln m) P^{-1}\Lambda P} \\ &= P^{-1} e^{(\ln n - \ln m) \Lambda}P \\ &= P^{-1}\,\text{diag}(e^{-(\ln n - \ln m)}, e^{-2(\ln n - \ln m)}, \dots, e^{-d(\ln n - \ln m)})\,P \\ &= P^{-1}\,\text{diag}\Big(\frac{m}{n}, \frac{m^2}{n^2}, \dots, \frac{m^d}{n^d}\Big)\,P \end{aligned} \end{equation} As seen, each entry of the final decay matrix is a linear combination of the powers (m/n)^1, (m/n)^2, \dots, (m/n)^d. Therefore, the historical memory of LegS-type ODEs exhibits at most polynomial decay, which is more long-tailed than exponential decay, theoretically providing better memory capacity.
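The eigenvalue claim and the telescoping product can be verified numerically for a small d. In the sketch below, the decay matrix is computed through NumPy's eigendecomposition A = V \text{diag}(w) V^{-1} (so V^{-1} plays the role of P); the function names and the choice d = 4, m = 3, n = 12 are illustrative, and this route is only meant for small d, where the eigendecomposition is numerically well behaved.

```python
import numpy as np

def legs_A(d):
    """HiPPO-LegS A: -sqrt((2n+1)(2k+1)) below the diagonal, -(n+1) on it."""
    n = np.arange(d)
    r = np.sqrt(2 * n + 1)
    return -np.tril(np.outer(r, r), k=-1) - np.diag(n + 1.0)

def decay(A, m, n):
    """e^{(ln n - ln m) A} via the eigendecomposition A = V diag(w) V^{-1}."""
    w, V = np.linalg.eig(A)
    scaled = V * (m / n) ** (-w)  # scales column i by (m/n)^(-w_i)
    return np.real(scaled @ np.linalg.inv(V))

d = 4
A = legs_A(d)
eigvals = np.sort(np.linalg.eigvals(A).real)  # should be -d, ..., -1

m, n = 3, 12
total = decay(A, m, n)
stepwise = np.eye(d)
for k in range(m, n):  # product of the per-step decay factors
    stepwise = decay(A, k, k + 1) @ stepwise
```

Because A is lower triangular, the top-left entry of the decay matrix is exactly the slowest mode m/n, which gives a quick sanity check on the result.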

Computational Efficiency

Finally, we point out that the A matrix of HiPPO-LegS is computationally efficient. Specifically, a naive implementation of a d \times d matrix multiplied by a d \times 1 column vector requires d^2 multiplications. However, the multiplication of the LegS A matrix with a vector can be reduced to \mathcal{O}(d). Furthermore, we can prove that the discretization in [eq:legs-ode-bilinear] can also be completed in \mathcal{O}(d).

To understand this, we first rewrite the HiPPO-LegS A matrix equivalently as: \begin{equation} A_{n,k} = \left\{\begin{array}{ll} n\delta_{n,k} - \sqrt{2n+1}\sqrt{2k+1}, & k \leq n \\ 0, & k > n\end{array}\right. \end{equation} For a vector v = [v_0, v_1, \dots, v_{d-1}], we have: \begin{equation} \begin{aligned} (Av)_n = \sum_{k=0}^n A_{n,k}v_k &= \sum_{k=0}^n \left(n\delta_{n,k} - \sqrt{2n+1}\sqrt{2k+1}\right)v_k \\ &= n v_n - \sqrt{2n+1}\sum_{k=0}^n \sqrt{2k+1}v_k \end{aligned} \end{equation} This involves three operations: the first term n v_n is an element-wise multiplication of the vector [0, 1, 2, \dots, d-1] and v; the second term \sqrt{2k+1}v_k is an element-wise multiplication of [1, \sqrt{3}, \sqrt{5}, \dots, \sqrt{2d-1}] and v, followed by a \text{cumsum} operation, and finally multiplying by \sqrt{2n+1} (another element-wise multiplication). Each step can be completed in \mathcal{O}(d), so the total complexity is \mathcal{O}(d).
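The three operations translate almost line by line into NumPy. The following is a minimal sketch (the function names are ours), with a dense reference implementation included only to check the result.

```python
import numpy as np

def legs_matvec(v):
    """Compute A @ v for the HiPPO-LegS A in O(d) using a cumulative sum."""
    n = np.arange(len(v))
    r = np.sqrt(2 * n + 1)
    # (Av)_n = n * v_n - sqrt(2n+1) * sum_{k<=n} sqrt(2k+1) * v_k
    return n * v - r * np.cumsum(r * v)

def legs_A_dense(d):
    """Dense O(d^2) reference matrix, used only for comparison."""
    n = np.arange(d)
    r = np.sqrt(2 * n + 1)
    return -np.tril(np.outer(r, r), k=-1) - np.diag(n + 1.0)

v = np.array([0.4, -1.2, 0.7, 2.0, -0.3])
fast = legs_matvec(v)
dense = legs_A_dense(len(v)) @ v
```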

Now let’s look at [eq:legs-ode-bilinear]. It contains two "matrix-vector" multiplications. One is (I+\lambda A)v, where \lambda is any real number; we just proved Av is efficient, so (I+\lambda A)v is as well. The second is (I-\lambda A)^{-1}v. We will prove this is also efficient. Solving z=(I-\lambda A)^{-1}v is equivalent to solving the equation v = (I-\lambda A)z. Using the expression for Az given above, we get: \begin{equation} v_n = z_n - \lambda \left(n z_n - \sqrt{2n+1}\sum_{k=0}^n \sqrt{2k+1}z_k\right) \end{equation} Let S_n = \sum_{k=0}^n \sqrt{2k+1}z_k, then z_n = \frac{S_n - S_{n-1}}{\sqrt{2n+1}}. Substituting this into the equation: \begin{equation} v_n = \frac{S_n - S_{n-1}}{\sqrt{2n+1}} - \lambda \left(n \frac{S_n - S_{n-1}}{\sqrt{2n+1}} - \sqrt{2n+1}S_n\right) \end{equation} Rearranging gives: \begin{equation} S_n = \frac{1 - \lambda n}{1+\lambda n + \lambda}S_{n-1} + \frac{\sqrt{2n+1}}{1+\lambda n + \lambda}v_n \end{equation} This is a scalar recursion that can be calculated serially or in parallel using Prefix Sum algorithms, with a complexity of \mathcal{O}(d) or \mathcal{O}(d \log d), which is much more efficient than \mathcal{O}(d^2).
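The scalar recursion for S_n can be sketched as a serial O(d) solver as follows, with S_{-1} = 0 as the starting value. The function names are ours, and the dense solve is included only as a reference check; the sketch assumes \lambda \geq 0, so that the denominators 1 + \lambda n + \lambda are never zero.

```python
import numpy as np

def legs_solve(v, lam):
    """Solve (I - lam * A) z = v for the HiPPO-LegS A in O(d), serially."""
    d = len(v)
    z = np.empty(d)
    S_prev = 0.0  # S_{-1} = 0 (empty sum)
    for n in range(d):
        r = np.sqrt(2 * n + 1)
        S = ((1 - lam * n) * S_prev + r * v[n]) / (1 + lam * n + lam)
        z[n] = (S - S_prev) / r  # recover z_n from consecutive partial sums
        S_prev = S
    return z

def legs_A_dense(d):
    """Dense O(d^2) reference matrix, used only for comparison."""
    n = np.arange(d)
    r = np.sqrt(2 * n + 1)
    return -np.tril(np.outer(r, r), k=-1) - np.diag(n + 1.0)

v = np.array([0.4, -1.2, 0.7, 2.0, -0.3])
lam = 0.25
z_fast = legs_solve(v, lam)
z_dense = np.linalg.solve(np.eye(len(v)) - lam * legs_A_dense(len(v)), v)
```

In [eq:legs-ode-bilinear], \lambda takes the value \frac{\epsilon}{2(t+\epsilon)}, which is positive, so the assumption above is satisfied.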

Fourier Basis

Finally, we conclude with a derivation using the Fourier basis. In the previous article, we introduced linear systems using Fourier series but only derived results for the local window form. For the Legendre polynomial basis, we derived both local window and full interval versions (LegT and LegS). Can the Fourier basis also yield a version analogous to LegS? What difficulties are involved? We explore this below.

Again, we skip the background. Following the notation from the previous article, the Fourier basis coefficients are: \begin{equation} c_n(T) = \int_0^1 u(t_{\leq T}(s)) e^{-2i\pi n s}ds \end{equation} Like LegS, to memorize the signal over the entire [0, T] interval, we need a mapping [0, 1] \mapsto [0, T]. We choose the simplest t_{\leq T}(s) = sT. Differentiating both sides with respect to T: \begin{equation} \frac{d}{dT}c_n(T) = \int_0^1 u'(sT) s e^{-2i\pi n s}ds \end{equation} Using integration by parts: \begin{equation} \begin{aligned} \frac{d}{dT}c_n(T) &= \frac{1}{T}\int_0^1 s e^{-2i\pi n s}d u(sT) \\ &= \frac{1}{T} u(sT) s e^{-2i\pi n s}\big|_{s=0}^{s=1} - \frac{1}{T}\int_0^1 u(sT) d(s e^{-2i\pi n s})\\ &= \frac{1}{T} u(T) - \frac{1}{T}\int_0^1 u(sT) e^{-2i\pi n s} ds + \frac{2i\pi n}{T}\int_0^1 u(sT) s e^{-2i\pi n s} ds\\ &= \frac{1}{T} u(T) - \frac{1}{T}c_n(T) + \frac{2i\pi n}{T}\int_0^1 u(sT) s e^{-2i\pi n s} ds\\ \end{aligned} \end{equation} In the previous article, we mentioned that a key reason HiPPO chooses Legendre polynomials is that (s+1)p_n'(s) can be decomposed into a linear combination of p_0(s), p_1(s), \dots, p_n(s), whereas s e^{-2i\pi n s} for the Fourier basis seemingly cannot. However, if error is allowed, this assertion is not strictly true, as we can also decompose s into a Fourier series: \begin{equation} s = \frac{1}{2} + \frac{i}{2\pi}\sum_{k\neq 0} \frac{1}{k} e^{2i\pi k s} \end{equation} This sum has infinite terms. 
Truncating it to finite terms introduces error, but we can ignore that for now and substitute it back: \begin{equation} \begin{aligned} &\frac{2i\pi n}{T}\int_0^1 u(sT) s e^{-2i\pi n s} ds \\ &= \frac{2i\pi n}{T}\int_0^1 u(sT) \left(\frac{1}{2} + \frac{i}{2\pi}\sum_{k\neq 0} \frac{1}{k} e^{2i\pi k s}\right) e^{-2i\pi n s} ds \\ &= \frac{i\pi n}{T}\int_0^1 u(sT) e^{-2i\pi n s} ds - \frac{1}{T}\sum_{k\neq 0} \frac{n}{k}\int_0^1 u(sT) e^{-2i\pi (n - k) s} ds \\ &= \frac{i\pi n}{T}c_n(T) - \frac{1}{T}\sum_{k\neq 0} \frac{n}{k}c_{n-k}(T) \\ &= \frac{i\pi n}{T}c_n(T) - \frac{1}{T}\sum_{k\neq n} \frac{n}{n - k}c_k(T) \\ \end{aligned} \end{equation} Thus: \begin{equation} \frac{d}{dT}c_n(T) = \frac{1}{T} u(T) + \frac{i\pi n - 1}{T}c_n(T) - \frac{1}{T}\sum_{k\neq n} \frac{n}{n - k}c_k(T) \end{equation} So we can write: \begin{equation} \begin{aligned} x'(t) &= \frac{A}{t}x(t) + \frac{B}{t}u(t)\\[8pt] A_{n,k} &= \left\{\begin{array}{ll}-\frac{n}{n-k}, & k \neq n \\ i\pi n - 1, & k = n\end{array}\right.\\[8pt] B_n &= 1 \end{aligned} \end{equation} In practice, we only need to truncate |n|, |k| \leq N to get a (2N+1) \times (2N+1) matrix. The error from truncation is generally acceptable, just as we introduced finite series approximations in HiPPO-LegT. For specific tasks, we choose an appropriate scale N such that the truncation error is negligible.
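As a sketch of the truncated system, the following builds the (2N+1) \times (2N+1) matrix A and the vector B with indices ordered as n, k = -N, \dots, N, and also evaluates a partial sum of the Fourier series of s to illustrate the truncation error. The function names, the index ordering, and the values N = 2, s = 0.3, K = 2000 are our own choices for illustration.

```python
import numpy as np

def fourier_legs_matrices(N):
    """Truncated Fourier-basis A and B; row/column order is -N, ..., N."""
    idx = np.arange(-N, N + 1)
    n = idx[:, None].astype(float)
    k = idx[None, :].astype(float)
    diff = n - k
    safe = np.where(diff == 0, 1.0, diff)  # placeholder on the diagonal
    A = (-n / safe).astype(complex)        # A_{n,k} = -n / (n - k), k != n
    np.fill_diagonal(A, 1j * np.pi * idx - 1)  # A_{n,n} = i*pi*n - 1
    B = np.ones(2 * N + 1, dtype=complex)
    return A, B

def s_partial_sum(s, K):
    """Partial Fourier sum of s on (0, 1), keeping the terms 0 < |k| <= K."""
    k = np.concatenate([np.arange(-K, 0), np.arange(1, K + 1)])
    return 0.5 + (1j / (2 * np.pi)) * np.sum(np.exp(2j * np.pi * k * s) / k)

N = 2
A, B = fourier_legs_matrices(N)
approx = s_partial_sum(0.3, K=2000)  # close to 0.3, with slow convergence
```

One quick sanity check is the n = 0 row: it reduces to c_0'(T) = \frac{1}{T}(u(T) - c_0(T)), which is exactly the update of a running mean, as expected since c_0(T) is the average of u over [0, T].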

For most people, this Fourier basis derivation might be easier to understand: Legendre polynomials are unfamiliar to many, especially the identities used in the LegT and LegS derivations, whereas most readers have at least some knowledge of Fourier series. However, the Fourier result might be less practical than LegS. First, it introduces complex numbers, which increases implementation complexity; second, the resulting A matrix is not a simple lower triangular matrix like that of LegS, which makes theoretical analysis more complex. Thus, consider this an exercise to deepen your understanding of HiPPO.

Summary

In this article, we supplemented the discussion on some legacy issues of HiPPO, including how to discretize ODEs, the excellent properties of LegS-type ODEs, and the derivation of the Fourier version of LegS. We hope this provides a more comprehensive understanding of HiPPO.

Original URL: https://kexue.fm/archives/10137
