In recent years, RNNs (Recurrent Neural Networks) have regained the interest of many researchers and users thanks to their efficient, linear-complexity training and inference, showing signs of a "Renaissance." Representative works include RWKV, RetNet, and Mamba. When RNNs are used as language models, their typical characteristic is that each generation step has constant space and time complexity; over the entire sequence, this amounts to constant space complexity and linear time complexity. Of course, everything has two sides. Compared to the dynamically growing KV Cache of Attention, the constant space complexity of RNNs often raises the suspicion of limited memory capacity, making it difficult for them to match Attention on Long Context tasks.
In this article, we show that Causal Attention can be rewritten in the form of an RNN, and each step of its generation can theoretically be performed with \mathcal{O}(1) space complexity (at the cost of extremely high time complexity, far exceeding quadratic). This suggests that the advantage of Attention (if any) is built on computational stacking rather than an intuitive stacking of memory; like RNNs, it essentially has a constant-magnitude memory capacity (memory bottleneck).
RNNs Beyond Linearity
Supporters of RNNs often present an argument that seems hard to refute: "Is your brain an RNN or Attention?"
Intuitively, the space complexity of RNN inference is constant, while the KV cache of Attention grows dynamically. Considering that the capacity of the human brain is finite, from this point of view RNNs do seem closer to the human brain. However, even if it is reasonable to believe that brain capacity limits the space complexity of each inference step to a constant, it does not limit the time complexity of each step to a constant. To put it another way, even if a person's per-step time complexity is constant, they do not necessarily scan a sequence of length L only once; they may go back and re-read parts of it (like flipping back through a book), so the total number of inference steps can significantly exceed L, leading to non-linear time complexity.
Taking this into account, I had a "sudden inspiration": could we consider, more generally, RNN models with constant space complexity but non-linear time complexity, so as to supplement the capabilities that mainstream RNNs lack (such as the "book-flipping" mentioned above)?
For a language modeling task, assuming the sample is a b c d e, the training task is to input a b c d and predict b c d e. A common RNN is shown below:
The problem with this type of RNN is the lack of "book-flipping" capability; each input is discarded after being read. The characteristic of Attention is that for every token read, it completely "flips through" the entire history. While this approach may have efficiency issues, it is undoubtedly the most straightforward way to introduce book-flipping capability. To give RNNs this capability as well, we can use them in a way that imitates Attention:
Just like Attention, every time a new token is read, the entire history is reviewed. Of course, one could say this isn’t a new RNN design, but rather a new way of using RNNs by simply modifying the input; whether it’s RWKV or Mamba, this can be applied. Under this usage, decoding can still be completed within constant space complexity, but the time complexity of each inference step grows linearly, resulting in a total time cost of \mathcal{O}(L^2).
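To make this usage concrete, here is a minimal NumPy sketch. The tanh cell, weight shapes, and toy sizes are illustrative assumptions rather than anything specified above; the only point is that each position re-reads the whole prefix with a single fixed-size state, so live memory stays \mathcal{O}(1) while total time grows to \mathcal{O}(L^2).

```python
import numpy as np

# Minimal sketch (not the article's code) of the "re-read the whole history"
# usage: a generic RNN cell -- here a plain tanh cell, purely illustrative --
# is re-run over the full prefix each time a new token arrives.
rng = np.random.default_rng(0)
d, vocab = 16, 100                       # toy hidden size and vocabulary
E  = rng.normal(size=(vocab, d))         # token embeddings
Wh = rng.normal(size=(d, d)) / d ** 0.5  # recurrent weight
Wx = rng.normal(size=(d, d)) / d ** 0.5  # input weight

def rnn_readout(tokens):
    """Scan a prefix with a single fixed-size state: O(1) space in its length."""
    h = np.zeros(d)
    for t in tokens:
        h = np.tanh(h @ Wh + E[t] @ Wx)
    return h  # would feed a softmax over the vocabulary to predict the next token

prefix = [3, 14, 15, 9, 2]
# Emitting position i re-reads prefix[:i+1] from scratch, so total time over a
# length-L sequence is O(L^2) while only one state vector h is ever alive.
outputs = [rnn_readout(prefix[:i + 1]) for i in range(len(prefix))]
```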
Attention is also an RNN
In fact, the class of models represented by Figure 2 is very broad; even Attention is just a special case of it, as shown below:
Compared to Figure 2, several arrows in Figure 3 are faded, representing that these positions are actually disconnected. Thus, Attention is merely a special case of Figure 2. Specifically, the calculation formula for Attention is: \begin{equation} o_i = \sum_{j=1}^i a_{i,j}v_j = \frac{\sum_{j=1}^i e^{q_i\cdot k_j} v_j}{\sum_{j=1}^i e^{q_i\cdot k_j}} \end{equation} Clearly, the summation in both the numerator and denominator can be written in recursive form: \begin{equation} \begin{pmatrix} y_i^{(t)} \\ z_i^{(t)} \end{pmatrix} = \begin{pmatrix} y_i^{(t-1)} \\ z_i^{(t-1)} \end{pmatrix} + e^{q_i\cdot k_{i-t+1}}\begin{pmatrix} v_{i-t+1} \\ 1 \end{pmatrix}\quad,\quad o_i = \frac{y_i^{(i)}}{z_i^{(i)}} \end{equation} According to the literature I have read, the earliest paper to propose the above formula and use it to optimize Attention calculation is "Self-attention Does Not Need O(n^2) Memory". The block matrix version of the above formula is the theoretical foundation of the current mainstream acceleration technology, Flash Attention. Since in Self-Attention, Q, K, V are all obtained from the same input through token-wise operations, the recursive form above can be exactly represented as Figure 3.
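As a sanity check on this recursion, here is a small NumPy sketch (the random data and shapes are assumptions for illustration). It accumulates the numerator and denominator exactly as in the formula above and compares the result with the ordinary softmax-weighted sum; practical implementations such as Flash Attention additionally carry a running maximum of the scores for numerical stability, which is omitted here.

```python
import numpy as np

def attention_out_recurrent(q_i, K_hist, V_hist):
    """o_i via the running numerator/denominator recursion above: O(1) memory per query."""
    y = np.zeros(V_hist.shape[1])      # accumulates sum_j e^{q_i.k_j} v_j
    z = 0.0                            # accumulates sum_j e^{q_i.k_j}
    for k_j, v_j in zip(K_hist, V_hist):
        w = np.exp(q_i @ k_j)
        y += w * v_j
        z += w
    return y / z

# Sanity check against the ordinary softmax-weighted sum.
rng = np.random.default_rng(0)
L, d = 8, 4
Q, K, V = rng.normal(size=(3, L, d))
i = L - 1
scores = np.exp(Q[i] @ K[: i + 1].T)
o_ref = scores @ V[: i + 1] / scores.sum()
assert np.allclose(o_ref, attention_out_recurrent(Q[i], K[: i + 1], V[: i + 1]))
```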
Of course, Figure 3 only illustrates one layer of Attention. Multiple layers can naturally be drawn, but the connections would look somewhat complex. For example, the case for two layers is shown below:
Constant Space Complexity
As stated at the beginning of this article, a common advantage of RNNs is that they can perform inference with constant space complexity and linear time complexity. Since Attention can also be written as an RNN, the natural question is: does it also have these two advantages under this formulation?
Obviously, since the RNN corresponding to Attention is one whose number of recurrent steps has grown to \mathcal{O}(L^2), linear time complexity is out of the question. The only thing worth considering is whether constant space complexity can be achieved. One's first reaction might be "no," because it is well known that Attention decoding requires a dynamically growing KV cache that is linear in the sequence length. However, that is only what the efficient implementation looks like. If we are willing to trade time for space regardless of the cost, how much further can the space complexity be reduced?
The answer might be surprising: If time is truly traded for space to the extreme, the space complexity can indeed be reduced to \mathcal{O}(1)!
This conclusion is not hard to imagine. First, the single-layer Attention shown in Figure 3 is no different in form from an ordinary single-layer RNN; therefore, it can obviously complete inference using a fixed amount of storage space. Next, let’s look at the multi-layer Attention shown in Figure 4. Its connections between layers are more complex, so it usually requires caching historical K and V for efficient calculation. But if we resolutely refuse to store the KV cache, then the K and V input for each layer and each inference step can be completely recomputed from the original input (recomputation). This leads to a lot of redundant calculations, so the total time complexity will far exceed quadratic complexity and is very "un-environmentally friendly," but the space complexity can indeed be maintained at \mathcal{O}(1).
Taking two-layer Attention as an example, the second layer of Attention uses the output of the first layer as input. Since each output of the first layer can be calculated in \mathcal{O}(1) space, as long as we are willing to sacrifice efficiency for recomputation, the second layer of Attention also only needs \mathcal{O}(1) space to complete. By extension, the third layer uses the output of the second layer, and the N-th layer uses the output of the (N-1)-th layer. Since the previous layer can always be completed in \mathcal{O}(1) space through recomputation, each layer and even the entire model can be computed in \mathcal{O}(1) space.
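The recomputation argument can be sketched in a few lines of NumPy. The model below is deliberately stripped down (no residuals, LayerNorm, MLP blocks, multi-head splitting, or scaling; names and sizes are assumptions), but it shows the key point: no K or V is ever cached, so apart from the \mathcal{O}(L) token IDs themselves only a handful of fixed-size vectors are alive at any moment, while the running time blows up far beyond quadratic.

```python
import numpy as np

# Stripped-down sketch of O(1)-space decoding for stacked causal attention via
# recomputation. Residuals, LayerNorm, MLP blocks, multi-head splitting and the
# 1/sqrt(d) scaling are all omitted; names and sizes are illustrative assumptions.
rng = np.random.default_rng(0)
d, num_layers, vocab = 8, 2, 50
E = rng.normal(size=(vocab, d))
Wq, Wk, Wv = (rng.normal(size=(num_layers, d, d)) / d ** 0.5 for _ in range(3))

def layer_out(n, i, tokens):
    """Output of attention layer n at position i, with no KV cache at all.

    Every lower-layer value is recomputed from the raw token IDs on demand, so
    apart from the O(L) token list itself the live state is a few d-dimensional
    accumulators per layer (constant in L), while the running time grows far
    beyond quadratic in len(tokens).
    """
    x_i = E[tokens[i]] if n == 0 else layer_out(n - 1, i, tokens)
    q = x_i @ Wq[n]
    y, z = np.zeros(d), 0.0                     # running numerator / denominator
    for j in range(i + 1):                      # "flip through" the whole history
        x_j = E[tokens[j]] if n == 0 else layer_out(n - 1, j, tokens)
        w = np.exp(q @ (x_j @ Wk[n]))
        y += w * (x_j @ Wv[n])
        z += w
    return y / z

tokens = [7, 3, 42, 11]
o_last = layer_out(num_layers - 1, len(tokens) - 1, tokens)  # top layer, last step
```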
This returns to the viewpoint at the beginning of the article: if Attention truly has any advantage over RNNs, it is only achieved through more computation. The intuitive expansion of "memory" is just an appearance of trading space for time; like RNNs, it essentially has a constant-capacity memory bottleneck.
Of course, some readers might think: "Isn’t trading time for space a common practice? This doesn’t seem like a valuable conclusion." Indeed, trading time for space is common, but it is not always possible. In other words, not all problems can have their space complexity reduced to \mathcal{O}(1) by trading time for space. This is a common but non-trivial property.
Reflections on Model Capabilities
The reason for pointing out this characteristic of Attention is not to actually use it for inference, but to help us further reflect on the capability bottlenecks of Attention.
First, if we really get into the details, \mathcal{O}(1) is actually incorrect; more strictly, it should be \mathcal{O}(L), because a quadratic-complexity RNN needs to repeatedly scan the historical sequence, which at least requires storing the original input and the outputs of the generation process, i.e., at least L integer token IDs. The space required for this is \mathcal{O}(L), and if L is large enough, \mathcal{O}(L) will be larger than \mathcal{O}(1). However, the \mathcal{O}(1) mentioned here mainly refers to the minimum space required by the internal computational layers of the LLM, equivalent to the hidden_state when acting as an RNN, which has at least (hidden_size * num_layers * 2) components, while the \mathcal{O}(L) space is reflected in the input and output. An intuitive analogy is to treat Attention as a computer with an infinite hard drive but fixed memory (RAM); it constantly reads data from the hard drive, performs calculations in memory, and writes the results back to the hard drive.
We know that if the memory itself is large and the data being processed is not, we tend to be more cavalier when programming, perhaps even loading all the data into memory so that the intermediate computation does not rely on hard drive I/O at all. Similarly, LLMs trained in the "large model, short sequence" regime tend to use the \mathcal{O}(1) fixed "memory" brought by model scale rather than the dynamic "hard drive" brought by sequence length, because at the current scale of LLMs the former is large enough, and SGD will "lazily" train the model as if it were a machine with infinite static memory (since for short sequences the memory is always sufficient). But in reality the model's static memory is finite. Therefore, for tasks that cannot be completed in \mathcal{O}(1) space, Attention-based models cannot generalize to inputs of arbitrary length.
For example, if we want to calculate the decimal representation y of 2^x and use Attention for conditional modeling p(y|x), the training corpus would be the concatenation \{x, \text{\textcolor{red}{[sep]}}, y\}, calculating only the loss for y. Note that y here is uniquely determined by the input x, so theoretically, 100% accuracy should be learnable. However, without a Chain of Thought (CoT) to dynamically increase the sequence length, the model can only place all calculation processes implicitly in "memory," which is always effective for short inputs. But in fact, memory is finite, while the space required to calculate 2^x increases with x. Therefore, there must exist a sufficiently large x such that the accuracy of p(y|x) cannot reach 100% (even for training accuracy). This is different from the length extrapolation problem discussed in "The Road to Transformer Upgrading: 16. ’Reviewing’ Length Extrapolation Techniques"; it is not caused by the OOD (Out-of-Distribution) of position encodings, but rather a capability defect brought about by "large model, short sequence" training without sufficient CoT guidance.
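For concreteness, a toy version of such a corpus could be built as follows (the literal "[sep]" string and the function name are illustrative assumptions, not the original training setup). Since 2^x has roughly 0.301x decimal digits, the space needed to produce y grows linearly in x while the model's internal state does not.

```python
# Toy construction of the {x, [sep], y} corpus for y = decimal digits of 2**x.
# The literal "[sep]" string stands in for the separator token; this is an
# illustrative assumption, not the original training setup.
def make_sample(x: int) -> str:
    return f"{x}[sep]{2 ** x}"

print(make_sample(10))        # 10[sep]1024
print(len(str(2 ** 1000)))    # 302 -- the answer has about 0.301*x digits
```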
So why is the current mainstream direction for scaling up still increasing the LLM's "memory", i.e., increasing the model's hidden_size and num_layers, rather than researching schemes like CoT to increase seq_len? The latter is certainly one of the mainstream research areas, but the core issue is that if memory becomes a bottleneck, it reduces the model's learning efficiency and universality. It's like when memory is small but the data volume is large: we need to save results to the hard drive in a timely manner and clear the memory, which means the algorithm must be more sophisticated and harder to write, and may even need to be customized for specific tasks. Under what circumstances does a memory bottleneck occur? Taking LLAMA2-70B as an example, its num_layers is 80 and hidden_size is 8192; multiplying them gives 640K, and multiplying by 2 gives roughly 1.3M. In other words, when the input length reaches the level of 1M tokens, the "memory" of LLAMA2-70B may become a bottleneck. Although training an LLM with a 1M-token context is still not easy, it is no longer out of reach; for example, Kimi has already launched a 1M-level model for internal testing.
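A quick back-of-envelope check of that threshold (the factor of 2 follows the estimate above; the exact constant does not matter for the argument):

```python
# Back-of-envelope check of when the fixed "RAM" of LLAMA2-70B is matched by
# the dynamic "hard drive" of the sequence, following the estimate above.
num_layers, hidden_size = 80, 8192
state_components = num_layers * hidden_size * 2
print(f"{state_components:,}")   # 1,310,720 -- on the order of 1M
# Once the context reaches roughly a million tokens, the sequence is no smaller
# than the model's fixed internal state, and the "memory" may become the bottleneck.
```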
Therefore, constantly increasing the model’s context length (hard drive) to accommodate more input and CoT, while simultaneously improving the scale of the model itself so that "memory" does not become a bottleneck, has become the main theme of current LLMs.
At the same time, this also negates a previous idea of mine: can we achieve the same effect as a large model by reducing the model scale and increasing seq_len? The answer is likely no, because small models have memory bottlenecks. To compensate using the hard drive provided by seq_len, one would need to set a sufficiently long CoT for every sample, which is more difficult than training a large model directly. If seq_len is increased only through simple schemes like repetition, no additional information is brought in, and there is no substantial gain. However, if the increase in seq_len is achieved through prefix tuning, it is possible to bridge the gap in space complexity, because prefix parameters are not calculated from the input sequence but are trained separately. This is equivalent to inserting a series of additional "memory sticks," thereby increasing the model's memory.
A Final Summary
In this article, we examined Attention from the perspective of a quadratic-complexity RNN and discovered that it has a constant space complexity bottleneck. This indicates that Attention does not essentially increase "memory" compared to RNNs, but merely increases the amount of computation significantly. The existence of this bottleneck suggests that Attention may face theoretical difficulties in length generalization for certain tasks (insufficient memory). Guiding the model to better utilize the dynamic "hard drive" brought by the seq_len dimension may be the key to solving this difficulty.