English (unofficial) translations of posts at kexue.fm

Mind-Blowing: Can Non-linear RNNs Actually Be Computed in Parallel?

Translated by Gemini Flash 3.0 Preview. Translations may be inaccurate; please refer to the original post for anything important.

In recent years, linear RNNs have attracted significant attention from researchers due to their characteristics such as parallelizable training and constant inference costs (for example, the previous article “Google’s New Work Attempts to ’Revive’ RNNs: Can RNNs Shine Again?”), which has allowed RNNs to maintain a “place at the table” amidst the widespread success of Transformers. However, it currently seems that this “place” belongs only to linear RNNs, as non-linear RNNs cannot be trained efficiently in parallel, making them “willing but unable” in the architecture competition.

However, a paper titled “Parallelizing Non-Linear Sequential Models over the Sequence Length” offers a different perspective. It proposes an iterative algorithm that claims to achieve parallel training for non-linear RNNs! Is it truly that magical? Let’s explore it further.

Finding Fixed Points

The original paper introduces its method in very general terms, with a focus on PDEs and ODEs. Here, we will start directly from RNNs. Consider a common simple non-linear RNN: \begin{equation} x_t = \tanh(Ax_{t-1} + u_t)\label{eq:rnn} \end{equation} Due to the presence of \tanh, it can only be computed serially. Now, let’s subtract Ax_{t-1} from both sides: \begin{equation} x_t - Ax_{t-1} = \tanh(Ax_{t-1} + u_t) - Ax_{t-1} \end{equation} Of course, this does not change the fact that it is a non-linear RNN. However, observe that if the x_{t-1} on the right-hand side were replaced by a known vector (as u_t is), the equation would become a linear RNN, which, according to the results in “Google’s New Work Attempts to ’Revive’ RNNs: Can RNNs Shine Again?”, can be computed in parallel. At this point, astute readers may have already guessed the next step: iterative solving!

First, we modify the above RNN into: \begin{equation} x_t^{(n)} - Ax_{t-1}^{(n)} = \tanh(Ax_{t-1}^{(n-1)} + u_t) - Ax_{t-1}^{(n-1)}\label{eq:rnn-iter} \end{equation} Starting from a given x_t^{(0)}, we iterate the above equation repeatedly. Ideally, it converges to a fixed point x_t^*, which is exactly the result of the original non-linear RNN. In theory, the total amount of computation for iterating Equation [eq:rnn-iter] exceeds that of recursing directly through Equation [eq:rnn]. However, each iteration step is a parallelizable linear RNN, so if convergence is fast enough that only a few iterations are needed, the total wall-clock time is usually shorter than that of the direct serial non-linear recursion (especially when the sequence length is very large).
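As a concreteness check, here is a minimal NumPy sketch of this iteration (my own illustration, not code from the paper; the dimensions, weight scale, and iteration count are arbitrary choices). The inner linear RNN is written as a serial loop for clarity, standing in for the parallel scan that would be used in practice:

```python
import numpy as np

def nonlinear_rnn(A, u):
    """Ground truth: x_t = tanh(A x_{t-1} + u_t), computed serially."""
    x, prev = np.zeros_like(u), np.zeros(u.shape[1])
    for t in range(len(u)):
        prev = np.tanh(A @ prev + u[t])
        x[t] = prev
    return x

def linear_rnn(A, b):
    """Solve x_t = A x_{t-1} + b_t; serial here, parallelizable as a scan."""
    x, prev = np.zeros_like(b), np.zeros(b.shape[1])
    for t in range(len(b)):
        prev = A @ prev + b[t]
        x[t] = prev
    return x

def fixed_point_rnn(A, u, n_iter=50):
    """Iterate x_t^{(n)} - A x_{t-1}^{(n)} = tanh(A x_{t-1}^{(n-1)} + u_t) - A x_{t-1}^{(n-1)}."""
    x = np.zeros_like(u)  # x^{(0)} = 0
    for _ in range(n_iter):
        x_prev = np.vstack([np.zeros((1, u.shape[1])), x[:-1]])  # x_{t-1}^{(n-1)}
        rhs = np.tanh(x_prev @ A.T + u) - x_prev @ A.T
        x = linear_rnn(A, rhs)  # each iteration step is one linear RNN solve
    return x

rng = np.random.default_rng(0)
A = 0.1 * rng.standard_normal((4, 4))  # small spectral norm aids convergence
u = rng.standard_normal((32, 4))
err = np.abs(fixed_point_rnn(A, u) - nonlinear_rnn(A, u)).max()
```

With a small spectral norm for A, the iterates match the exact serial recursion to high precision; for large or poorly-conditioned A, convergence is not guaranteed, which is exactly the issue discussed later.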

Simplified Form

In fact, the reason non-linear RNNs are slow is not just because they cannot be computed in parallel; more importantly, they contain a large number of non-element-wise operations, such as the matrix operation Ax_{t-1} inside the \tanh in Equation [eq:rnn]. Linear RNNs are fast not only because they allow parallel training but also because they can transform matrix multiplication into element-wise multiplication through diagonalization—for element-wise multiplication, even serial computation is not too slow.

When we transform a non-linear RNN into an iteration of linear RNNs via Equation [eq:rnn-iter], we can likewise enjoy the “treatment” of linear RNN diagonalization to speed up computation. Specifically, diagonalizing A as P\Lambda P^{-1} over the complex field, Equation [eq:rnn-iter] becomes: \begin{equation} x_t^{(n)} - P\Lambda P^{-1} x_{t-1}^{(n)} = \tanh(P\Lambda P^{-1} x_{t-1}^{(n-1)} + u_t) - P\Lambda P^{-1} x_{t-1}^{(n-1)} \end{equation} Multiplying both sides by P^{-1} on the left: \begin{equation} P^{-1} x_t^{(n)} - \Lambda P^{-1} x_{t-1}^{(n)} = P^{-1}\tanh(P\Lambda P^{-1} x_{t-1}^{(n-1)} + u_t) - \Lambda P^{-1} x_{t-1}^{(n-1)} \end{equation} Letting y_t = P^{-1} x_t, the above simplifies to: \begin{equation} y_t^{(n)} - \Lambda y_{t-1}^{(n)} = P^{-1}\tanh(P\Lambda y_{t-1}^{(n-1)} + u_t) - \Lambda y_{t-1}^{(n-1)} \end{equation} Since an RNN is generally followed by a projection layer, the P in x_t = P y_t can, in principle, be merged into that external projection layer. That is to say, the equation above theoretically possesses the same expressive power as the original Equation [eq:rnn], yet because \Lambda is a diagonal matrix, the cost of each recursion step is significantly reduced. However, the equation still involves the inverse matrix P^{-1}, which is not only computationally expensive but also unfavorable for optimization. We can therefore simply replace P^{-1} and P\Lambda with two independent parameter matrices P and Q: \begin{equation} y_t^{(n)} - \Lambda y_{t-1}^{(n)} = P\tanh(Q y_{t-1}^{(n-1)} + u_t) - \Lambda y_{t-1}^{(n-1)} \end{equation} which is valid as long as the initialization satisfies PQ=\Lambda.
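The change of variables can be checked numerically. The sketch below (my own illustration, with arbitrary sizes) diagonalizes a random A and confirms that the element-wise recursion in y = P^{-1} x reproduces the full matrix recursion in x:

```python
import numpy as np

rng = np.random.default_rng(1)
d, T = 4, 16
A = 0.1 * rng.standard_normal((d, d))
b = rng.standard_normal((T, d))

lam, P = np.linalg.eig(A)      # A = P diag(lam) P^{-1}, generally complex
P_inv = np.linalg.inv(P)

# Full matrix recursion: x_t = A x_{t-1} + b_t
x, prev = np.zeros((T, d)), np.zeros(d)
for t in range(T):
    prev = A @ prev + b[t]
    x[t] = prev

# Diagonalized recursion on y = P^{-1} x: y_t = lam * y_{t-1} + P^{-1} b_t
b_y = b @ P_inv.T
y, prev = np.zeros((T, d), dtype=complex), np.zeros(d, dtype=complex)
for t in range(T):
    prev = lam * prev + b_y[t]  # element-wise multiply, no matrix product
    y[t] = prev

x_back = (y @ P.T).real        # recover x_t = P y_t
err = np.abs(x_back - x).max()
```

Each recursion step in y costs O(d) multiplications instead of O(d^2), which is why even the serial fallback of a diagonalized linear RNN is not too slow.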

The Perturbation Idea

Assuming x_t^{(0)}=0, Equation [eq:rnn-iter] essentially decomposes the original non-linear RNN into a series of linear RNNs: \begin{equation} \begin{array}{c} x_t^{(1)} - Ax_{t-1}^{(1)} = \tanh(u_t)\\ x_t^{(2)} - Ax_{t-1}^{(2)} = \tanh(Ax_{t-1}^{(1)} + u_t) - Ax_{t-1}^{(1)} \\ \vdots \\ x_t^{(n)} - Ax_{t-1}^{(n)} = \tanh(Ax_{t-1}^{(n-1)} + u_t) - Ax_{t-1}^{(n-1)} \\ \vdots \\ \end{array}\label{eq:rnns} \end{equation} If we assume x_{t-1} and u_t are small quantities, then using \tanh x \approx x on the right side of Equation [eq:rnn] gives: \begin{equation} x_t = \tanh(Ax_{t-1} + u_t) \approx Ax_{t-1} + u_t \approx Ax_{t-1} + \tanh(u_t)\label{eq:rnn-approx} \end{equation} This happens to be the first equation in [eq:rnns]. Therefore, if the assumption holds, x_t^{(1)} might already be sufficiently close to the ideal x_t^*, and each subsequent iteration step rapidly approaches it. From this, we can see that “subtracting Ax_{t-1} from both sides” is the key. This makes the first iteration of [eq:rnn-iter] close to a first-order linear approximation of the original non-linear RNN, which improves convergence speed. This is a classic operation in mathematical physics known as “perturbation.”
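Under the small-input assumption, the first iterate alone should already be a good approximation. The sketch below (my own toy check, with arbitrarily chosen scales) compares x^{(1)} against the exact serial recursion for small u_t:

```python
import numpy as np

def nonlinear_rnn(A, u):
    """Exact serial recursion x_t = tanh(A x_{t-1} + u_t)."""
    x, prev = np.zeros_like(u), np.zeros(u.shape[1])
    for t in range(len(u)):
        prev = np.tanh(A @ prev + u[t])
        x[t] = prev
    return x

def first_iterate(A, u):
    """x^{(1)}: the linear RNN x_t = A x_{t-1} + tanh(u_t)."""
    x, prev = np.zeros_like(u), np.zeros(u.shape[1])
    for t in range(len(u)):
        prev = A @ prev + np.tanh(u[t])
        x[t] = prev
    return x

rng = np.random.default_rng(2)
A = 0.1 * rng.standard_normal((4, 4))
u = 0.1 * rng.standard_normal((32, 4))   # small inputs, as the expansion assumes
err = np.abs(first_iterate(A, u) - nonlinear_rnn(A, u)).max()
```

The discrepancy is dominated by the cubic remainder of \tanh x \approx x, so with inputs of order 0.1 the first linear iterate already tracks the non-linear trajectory closely.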

Accelerating Convergence

According to the idea of perturbation methods, the key to increasing convergence speed is to improve the accuracy of the approximate expansion. For example, a simpler improvement is to assume only x_{t-1} is a small quantity. Then, based on the first-order Taylor expansion (treating u_t as a column vector, where \circ denotes the Hadamard product): \begin{equation} x_t = \tanh(Ax_{t-1} + u_t) \approx \tanh(u_t) + (\mathop{\mathrm{sech}}^2 u_t \circ A)x_{t-1} \end{equation} Consequently, the improved result is that Equation [eq:rnn-iter] becomes: \begin{equation} x_t^{(n)} - A_t x_{t-1}^{(n)} = \tanh(Ax_{t-1}^{(n-1)} + u_t) - A_t x_{t-1}^{(n-1)}\label{eq:iter-plus1} \end{equation} where A_t = \mathop{\mathrm{sech}}^2 u_t \circ A. A more refined improvement is to perform the expansion based on the result of the previous iteration step at each iteration: \begin{equation} \begin{aligned} x_t =&\, \tanh(Ax_{t-1} + u_t) \\ \approx&\, \tanh(Ax_{t-1}^{(n-1)} + u_t) + (\mathop{\mathrm{sech}}^2 (Ax_{t-1}^{(n-1)} + u_t) \circ A)(x_{t-1} - x_{t-1}^{(n-1)}) \end{aligned} \end{equation} Thus, Equation [eq:rnn-iter] becomes: \begin{equation} x_t^{(n)} - A_t^{(n)} x_{t-1}^{(n)} = \tanh(Ax_{t-1}^{(n-1)} + u_t) - A_t^{(n)} x_{t-1}^{(n-1)}\label{eq:iter-plus2} \end{equation} where A_t^{(n)} = \mathop{\mathrm{sech}}^2 (Ax_{t-1}^{(n-1)} + u_t) \circ A. This final iterative format is actually “Newton’s method” for finding numerical solutions to equations, which possesses quadratic convergence speed.
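Here is a sketch of the Newton-style iteration [eq:iter-plus2] in the same toy setup (my own illustration; scales and iteration count are arbitrary). Note that because A_t^{(n)} depends on both t and n, the inner linear solve must be written as a time-varying serial recursion:

```python
import numpy as np

def nonlinear_rnn(A, u):
    """Exact serial recursion x_t = tanh(A x_{t-1} + u_t)."""
    x, prev = np.zeros_like(u), np.zeros(u.shape[1])
    for t in range(len(u)):
        prev = np.tanh(A @ prev + u[t])
        x[t] = prev
    return x

def newton_rnn(A, u, n_iter=8):
    """Iterate x_t^{(n)} - A_t^{(n)} x_{t-1}^{(n)} = tanh(z_t) - A_t^{(n)} x_{t-1}^{(n-1)},
    where z_t = A x_{t-1}^{(n-1)} + u_t and A_t^{(n)} = sech^2(z_t) o A."""
    x = np.zeros_like(u)
    for _ in range(n_iter):
        x_prev = np.vstack([np.zeros((1, u.shape[1])), x[:-1]])  # x_{t-1}^{(n-1)}
        z = x_prev @ A.T + u
        J = 1.0 / np.cosh(z) ** 2                 # sech^2(z_t), shape (T, d)
        rhs = np.tanh(z) - J * (x_prev @ A.T)     # right-hand side of the update
        out, prev = np.zeros_like(u), np.zeros(u.shape[1])
        for t in range(len(u)):                   # A_t varies with t: serial solve
            prev = J[t] * (A @ prev) + rhs[t]
            out[t] = prev
        x = out
    return x

rng = np.random.default_rng(3)
A = 0.1 * rng.standard_normal((4, 4))
u = rng.standard_normal((32, 4))
err = np.abs(newton_rnn(A, u) - nonlinear_rnn(A, u)).max()
```

In this well-behaved regime a handful of Newton iterations reach high precision, illustrating the quadratic convergence claimed above; the price is exactly the time-varying transition matrix discussed in the next section.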

Why Bother with Convergence?

Theoretically, the two improvements in [eq:iter-plus1] and [eq:iter-plus2] indeed increase convergence speed. However, they make the matrix A in each linear recursion step dependent on t or even n, which significantly increases the complexity of parallelization and prevents the use of the diagonalization trick from the “Simplified Form” section for acceleration. On the other hand, if the iterative format of [eq:rnn-iter] is maintained, while there are many efficiency benefits, convergence cannot be well guaranteed.

Is this contradiction irreconcilable? In the author’s view, the most direct approach is simply to “not worry about it.” Having derived Equation [eq:rnn-iter] with the help of non-linear RNNs, one can forget the original non-linear RNN and treat Equation [eq:rnn-iter] itself as the basic model. That is, why worry about whether Equation [eq:rnn-iter] converges to the original non-linear RNN? Why not take it directly as a new starting point? Whatever gradient descent learns is the answer; if the learned result does not converge to the original non-linear RNN, that simply means not converging to it is the better fit.

Once this mental constraint is cast aside, many problems become clear. First, even though Equation [eq:iter-plus2] theoretically enjoys excellent convergence speed, that guarantee is conditional, and in a deep learning context ensuring those conditions is a luxury. In other words, even the convergence of Equation [eq:iter-plus2] is not absolutely guaranteed, so why single out Equation [eq:rnn-iter] for criticism? Second, once Equation [eq:rnn-iter] is treated as a new starting point, we can simply understand it as a new way of using linear RNNs, or as a way of remedying the deficiencies of linear RNNs (for example, that linear RNNs are not Turing complete), which makes it more actionable.

Overall, ignoring convergence seems to better break the mental deadlock and allow for the exploration of more general results.

General Cases

The preceding “lengthy discussion” centered only on simple non-linear RNNs, i.e., Equation [eq:rnn]. What about the more commonly used LSTM and GRU?

Taking GRU as an example, its original form is: \begin{equation} \begin{aligned} z_{t} & = \sigma \left( W_{z} x_{t} + U_{z} h_{t - 1} + b_{z} \right) \\ r_{t} & = \sigma \left( W_{r} x_{t} + U_{r} h_{t - 1} + b_{r} \right) \\ \hat{h}_t & = \tanh \left( W_{h} x_{t} + U_{h} (r_t \circ h_{t - 1}) + b_{c} \right)\\ h_{t} & = \left(1 - z_{t}\right) \circ h_{t - 1} + z_{t} \circ \hat{h}_t \end{aligned} \end{equation} In the initial stage, all gates can be approximately viewed as \frac{1}{2}. Then, imitating Equation [eq:rnn-approx]: \begin{equation} \begin{aligned} h_{t} &\, = \left(1 - z_{t}\right) \circ h_{t - 1} + z_{t} \circ \hat{h}_t \\ &\, \approx \frac{1}{2} h_{t - 1} + \frac{1}{2} \hat{h}_t \\ &\, \approx \frac{1}{2} h_{t - 1} + \frac{1}{2} \left(\tanh ( W_{h} x_{t} + b_{c} ) + \frac{1}{2}U_{h} h_{t - 1}\right) \\ &\, = \frac{1}{2} \left(I + \frac{1}{2}U_{h}\right)h_{t - 1} + \frac{1}{2} \tanh ( W_{h} x_{t} + b_{c} ) \end{aligned} \end{equation} So we can choose A=\frac{1}{2} \left(I + \frac{1}{2}U_{h}\right) and rewrite the GRU as an iteration: \begin{equation} \begin{aligned} z_{t}^{(n)} & = \sigma \left( W_{z} x_{t} + U_{z} h_{t - 1}^{(n-1)} + b_{z} \right) \\ r_{t}^{(n)} & = \sigma \left( W_{r} x_{t} + U_{r} h_{t - 1}^{(n-1)} + b_{r} \right) \\ \hat{h}_t^{(n)} & = \tanh \left( W_{h} x_{t} + U_{h} (r_t^{(n)} \circ h_{t - 1}^{(n-1)}) + b_{c} \right)\\ h_{t}^{(n)} & = Ah_{t-1}^{(n)} - Ah_{t-1}^{(n - 1)} + \left(1 - z_{t}^{(n)}\right) \circ h_{t - 1}^{(n-1)} + z_{t}^{(n)} \circ \hat{h}_t^{(n)} \end{aligned} \end{equation}
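A toy NumPy check of this GRU rewrite (my own sketch, not from the paper; the weight scale s, dimensions, and iteration count are arbitrary choices): with small weights, the iteration converges to the exact serial GRU.

```python
import numpy as np

rng = np.random.default_rng(4)
d, T, s = 4, 16, 0.1
Wz, Uz, bz = s * rng.standard_normal((d, d)), s * rng.standard_normal((d, d)), np.zeros(d)
Wr, Ur, br = s * rng.standard_normal((d, d)), s * rng.standard_normal((d, d)), np.zeros(d)
Wh, Uh, bc = s * rng.standard_normal((d, d)), s * rng.standard_normal((d, d)), np.zeros(d)
x_seq = rng.standard_normal((T, d))

sigmoid = lambda v: 1.0 / (1.0 + np.exp(-v))

def gru(x_seq):
    """Exact serial GRU."""
    out, h = np.zeros((T, d)), np.zeros(d)
    for t in range(T):
        z = sigmoid(Wz @ x_seq[t] + Uz @ h + bz)
        r = sigmoid(Wr @ x_seq[t] + Ur @ h + br)
        hh = np.tanh(Wh @ x_seq[t] + Uh @ (r * h) + bc)
        h = (1 - z) * h + z * hh
        out[t] = h
    return out

A = 0.5 * (np.eye(d) + 0.5 * Uh)   # from the initial-stage expansion above

def gru_iter(x_seq, n_iter=40):
    h = np.zeros((T, d))
    for _ in range(n_iter):
        h_prev = np.vstack([np.zeros((1, d)), h[:-1]])   # h_{t-1}^{(n-1)}
        z = sigmoid(x_seq @ Wz.T + h_prev @ Uz.T + bz)
        r = sigmoid(x_seq @ Wr.T + h_prev @ Ur.T + br)
        hh = np.tanh(x_seq @ Wh.T + (r * h_prev) @ Uh.T + bc)
        rhs = (1 - z) * h_prev + z * hh - h_prev @ A.T
        out, prev = np.zeros((T, d)), np.zeros(d)
        for t in range(T):                               # linear RNN h_t = A h_{t-1} + rhs_t
            prev = A @ prev + rhs[t]
            out[t] = prev
        h = out
    return h

err = np.abs(gru(x_seq) - gru_iter(x_seq)).max()
```

At a fixed point, the A h_{t-1} terms cancel and the update reduces exactly to the original GRU, so any converged iterate reproduces the serial computation.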

In general, this transformation of non-linear RNNs into iterations of linear RNNs is, from a practical perspective, a way of using the non-linear RNN as a guide for deriving a parameter-sharing and combination scheme for multi-layer linear RNNs. If it iterates n times, it incurs the computational cost of n layers of linear RNNs. This naturally prompts a thought: unless non-linear RNNs such as GRU and LSTM can be shown to hold an absolute advantage, wouldn’t it be better to simply stack several layers of “linear RNN + MLP”?

Summary

This article briefly explored the parallel computation problem of non-linear RNNs. Through the “perturbation” idea from mathematical physics, we can transform non-linear RNNs into iterations of linear RNNs, thereby utilizing the parallelizability of linear RNNs to achieve parallelization for non-linear RNNs.

Reprinting: Please include the original address of this article:
https://kexue.fm/archives/9783

For more detailed reprinting matters, please refer to: “Scientific Space FAQ”