The comparison between Pre-Norm and Post-Norm is a "long-standing" topic that this blog has discussed several times, for example in "A Brief Discussion on the Initialization, Parameterization, and Normalization of Transformer" and "Model Optimization Notes: Why is the Initial Standard Deviation of BERT 0.02?". The current, fairly clear conclusion is that under the same settings, Pre-Norm is often easier to train, but its final performance is usually not as good as Post-Norm's. Why Pre-Norm is easier to train is easy to see: its identity path is more prominent. But why does it perform worse?
I had no good answer to this until I saw a reply by @Tang Xianghao on Zhihu, which made me realize that the problem has a very intuitive explanation. Let’s explore it together in this article.
Basic Conclusion
The formulas for Pre-Norm and Post-Norm are as follows:
$$
\begin{aligned}
\text{Pre-Norm: } \quad \boldsymbol{x}_{t+1} &= \boldsymbol{x}_t + F_t(\text{Norm}(\boldsymbol{x}_t)) \\
\text{Post-Norm: }\quad \boldsymbol{x}_{t+1} &= \text{Norm}(\boldsymbol{x}_t + F_t(\boldsymbol{x}_t))
\end{aligned}
$$
In Transformers, $\text{Norm}$ mainly refers to Layer Normalization, but in general models it can also be Batch Normalization, Instance Normalization, etc. The relevant conclusions are essentially universal.
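As a minimal sketch of the two update rules, the following uses plain Python lists, a parameter-free LayerNorm (no learnable gain or bias) as $\text{Norm}$, and a hypothetical random linear map in place of the sublayer $F_t$ (attention or feed-forward in a real Transformer). The names `layer_norm`, `make_linear`, `pre_norm_step` and `post_norm_step` are illustrative only, not taken from any library.

```python
import math
import random

def layer_norm(x, eps=1e-6):
    """Parameter-free LayerNorm: subtract the mean, divide by the standard deviation."""
    n = len(x)
    mean = sum(x) / n
    var = sum((v - mean) ** 2 for v in x) / n
    return [(v - mean) / math.sqrt(var + eps) for v in x]

def make_linear(d, rng):
    """Hypothetical stand-in for F_t: a random d x d linear map (entry variance 1/d)."""
    w = [[rng.gauss(0.0, 1.0 / math.sqrt(d)) for _ in range(d)] for _ in range(d)]
    return lambda z: [sum(a * b for a, b in zip(row, z)) for row in w]

def pre_norm_step(x, f):
    # x_{t+1} = x_t + F_t(Norm(x_t)): the identity path passes x_t through unchanged
    return [a + b for a, b in zip(x, f(layer_norm(x)))]

def post_norm_step(x, f):
    # x_{t+1} = Norm(x_t + F_t(x_t)): the sum is renormalized after every layer
    return layer_norm([a + b for a, b in zip(x, f(x))])

# Usage: one step of each variant from the same normalized input
rng = random.Random(0)
d = 64
x = layer_norm([rng.gauss(0.0, 1.0) for _ in range(d)])
f = make_linear(d, rng)
y_pre = pre_norm_step(x, f)
y_post = post_norm_step(x, f)
print(f"|x| = {math.sqrt(sum(v * v for v in x)):.2f}, "
      f"|pre| = {math.sqrt(sum(v * v for v in y_pre)):.2f}, "
      f"|post| = {math.sqrt(sum(v * v for v in y_post)):.2f}")
```

The Pre-Norm output typically has a larger norm than its input, because an un-normalized branch output is added on top of $\boldsymbol{x}_t$, whereas the Post-Norm output is always rescaled to zero mean and unit variance.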
Among the materials I have found, two works show Post-Norm outperforming Pre-Norm: "Understanding the Difficulty of Training Transformers" and "RealFormer: Transformer Likes Residual Attention". My own comparative experiments also show that the Post-Norm structure transfers better: in pre-training, Pre-Norm and Post-Norm reach roughly the same results, but after fine-tuning, Post-Norm is clearly better.
Readers might ask: doesn’t "On Layer Normalization in the Transformer Architecture" show that Pre-Norm is better than Post-Norm? Isn't that a contradiction? Not really. That paper compares Pre-Norm and Post-Norm under identical training configurations, which only shows that Pre-Norm is easier to train. To reach its best performance, Post-Norm needs a different training configuration from Pre-Norm (for example, Pre-Norm can be trained without warmup, whereas Post-Norm usually requires it). So the two conclusions do not conflict.
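Warmup here means starting training with a small learning rate and ramping it up over the first steps. The sketch below shows one common form, a linear ramp followed by a constant rate; the function name and parameters are illustrative, and this is not claimed to be the schedule used in the cited paper or in the experiments mentioned above. Many setups also decay the rate after warmup (the original Transformer, for instance, decays it with the inverse square root of the step).

```python
def warmup_lr(step, base_lr, warmup_steps):
    """Learning rate at a 0-indexed `step`: linear warmup, then constant.

    warmup_steps <= 0 means no warmup (a setting Pre-Norm often tolerates);
    Post-Norm usually needs warmup_steps > 0 to train stably.
    """
    if warmup_steps <= 0:
        return base_lr
    # Ramp linearly from base_lr / warmup_steps up to base_lr, then hold it
    return base_lr * min(1.0, (step + 1) / warmup_steps)

# Usage: the first few learning rates with a 4-step warmup
schedule = [warmup_lr(s, 1e-3, 4) for s in range(6)]
print(schedule)
```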
Intuitive Understanding
Why does Pre-Norm perform worse than Post-Norm? The answer given by @Tang Xianghao on Zhihu is that the depth of Pre-Norm is "diluted"! In other words, the effective depth of an $L$-layer Pre-Norm model is smaller than that of an $L$-layer Post-Norm model, and the lost depth shows up as worse performance.
How can we understand this? It is quite simple. For a Pre-Norm model, unrolling the recurrence gives
$$
\begin{aligned}
\boldsymbol{x}_{t+1} &= \boldsymbol{x}_t + F_t(\text{Norm}(\boldsymbol{x}_t)) \\
&= \boldsymbol{x}_{t-1} + F_{t-1}(\text{Norm}(\boldsymbol{x}_{t-1})) + F_t(\text{Norm}(\boldsymbol{x}_t)) \\
&= \cdots \\
&= \boldsymbol{x}_0 + F_0(\text{Norm}(\boldsymbol{x}_0)) + \cdots + F_{t-1}(\text{Norm}(\boldsymbol{x}_{t-1})) + F_t(\text{Norm}(\boldsymbol{x}_t))
\end{aligned}
$$
Since each term is of the same order of magnitude, we have $\boldsymbol{x}_{t+1} = \mathcal{O}(t+1)$, while $\boldsymbol{x}_{t+1}$ differs from $\boldsymbol{x}_t$ by only the single term $F_t(\text{Norm}(\boldsymbol{x}_t))$. The difference between layer $t+1$ and layer $t$ is therefore like the difference between $t+1$ and $t$: when $t$ is large, the relative difference between the two is very small. Because $\text{Norm}$ removes the overall scale, this small relative difference means $\text{Norm}(\boldsymbol{x}_{t+1})$ is very close to $\text{Norm}(\boldsymbol{x}_t)$, so $F_{t+1}(\text{Norm}(\boldsymbol{x}_{t+1}))$ is very close to $F_{t+1}(\text{Norm}(\boldsymbol{x}_t))$. Therefore:
$$
\begin{aligned}
& F_t(\text{Norm}(\boldsymbol{x}_t)) + F_{t+1}(\text{Norm}(\boldsymbol{x}_{t+1})) \\
\approx\ & F_t(\text{Norm}(\boldsymbol{x}_t)) + F_{t+1}(\text{Norm}(\boldsymbol{x}_t)) \\
=\ & \begin{pmatrix} 1 & 1 \end{pmatrix} \begin{pmatrix} F_t \\ F_{t+1} \end{pmatrix} (\text{Norm}(\boldsymbol{x}_t))
\end{aligned}
$$
That is, when $t$ is large, layers $t$ and $t+1$ together are approximately equivalent to a single, wider layer $t$. Consequently, in Pre-Norm, stacking more layers tends to add width rather than depth, and the more layers there are, the more "hollow" those layers become.
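The premise of this argument, that consecutive Pre-Norm states become relatively close as $t$ grows, can be checked with a toy simulation. This is a sketch under assumptions: each $F_t$ is a hypothetical random linear map, $\text{Norm}$ is a parameter-free LayerNorm, and the width (64) and depth (48) are arbitrary. In this toy setup the branch outputs are nearly uncorrelated with $\boldsymbol{x}_t$, so the norm of $\boldsymbol{x}_t$ grows roughly like $\sqrt{t}$ rather than $t$; the relative difference between consecutive layers still shrinks as $t$ grows, which is the part the argument relies on. Post-Norm is run with the same branches for comparison.

```python
import math
import random

def layer_norm(x, eps=1e-6):
    # Parameter-free LayerNorm (no learnable gain/bias), standing in for Norm(.)
    n = len(x)
    mean = sum(x) / n
    var = sum((v - mean) ** 2 for v in x) / n
    return [(v - mean) / math.sqrt(var + eps) for v in x]

def l2(x):
    return math.sqrt(sum(v * v for v in x))

def cosine(a, b):
    return sum(u * v for u, v in zip(a, b)) / (l2(a) * l2(b))

def random_branch(d, rng):
    # Hypothetical F_t: random linear map with entry variance 1/d (roughly norm-preserving)
    w = [[rng.gauss(0.0, 1.0 / math.sqrt(d)) for _ in range(d)] for _ in range(d)]
    return lambda z: [sum(a * b for a, b in zip(row, z)) for row in w]

def simulate(depth=48, d=64, seed=0):
    """Per layer: (||x_{t+1} - x_t|| / ||x_t||, cosine of the normalized states)."""
    rng = random.Random(seed)
    branches = [random_branch(d, rng) for _ in range(depth)]
    x0 = layer_norm([rng.gauss(0.0, 1.0) for _ in range(d)])
    stats = {"pre": [], "post": []}
    x_pre, x_post = x0, x0
    for f in branches:
        # Pre-Norm: x_{t+1} = x_t + F_t(Norm(x_t))
        nxt = [a + b for a, b in zip(x_pre, f(layer_norm(x_pre)))]
        rel = l2([a - b for a, b in zip(nxt, x_pre)]) / l2(x_pre)
        stats["pre"].append((rel, cosine(layer_norm(x_pre), layer_norm(nxt))))
        x_pre = nxt
        # Post-Norm: x_{t+1} = Norm(x_t + F_t(x_t)); both states are already normalized
        nxt = layer_norm([a + b for a, b in zip(x_post, f(x_post))])
        rel = l2([a - b for a, b in zip(nxt, x_post)]) / l2(x_post)
        stats["post"].append((rel, cosine(x_post, nxt)))
        x_post = nxt
    return stats

stats = simulate()
for t in (0, 1, 9, 47):
    (pr, pc), (qr, qc) = stats["pre"][t], stats["post"][t]
    print(f"layer {t:2d}  Pre-Norm: rel={pr:.3f} cos={pc:.3f}  "
          f"Post-Norm: rel={qr:.3f} cos={qc:.3f}")
```

In runs of this sketch, the Pre-Norm relative difference should start near 1 and shrink roughly like $1/\sqrt{t+1}$, with the cosine between $\text{Norm}(\boldsymbol{x}_t)$ and $\text{Norm}(\boldsymbol{x}_{t+1})$ approaching 1, while the Post-Norm relative difference stays roughly constant from layer to layer.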
Simply put, the Pre-Norm structure inadvertently increases the model’s width while decreasing its depth. Since depth generally matters more than width, this unintended loss of depth leads to worse final performance. Post-Norm is exactly the opposite. As analyzed in "A Brief Discussion on the Initialization, Parameterization, and Normalization of Transformer", each Norm weakens the weight of the identity branch, so Post-Norm places more emphasis on the residual branch. The layers in Post-Norm are therefore "full-weight," and once the model is trained well, it performs better.
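To make "wider rather than deeper" concrete: if each residual branch is, say, a two-layer ReLU MLP $F(\boldsymbol{z}) = W_{\text{out}}\,\text{ReLU}(W_{\text{in}}\boldsymbol{z})$ (a hypothetical simplification with no biases; real Transformer blocks also contain attention), then two branches evaluated on the same input sum to exactly one MLP whose hidden layer is twice as wide. This is the $\begin{pmatrix} 1 & 1 \end{pmatrix}$ identity from the derivation, checked numerically below; in Pre-Norm the two inputs are only approximately equal, so there the equivalence is approximate.

```python
import random

def relu(v):
    return [max(0.0, a) for a in v]

def matvec(m, v):
    return [sum(a * b for a, b in zip(row, v)) for row in m]

def mlp(w_in, w_out, z):
    # Hypothetical residual branch F(z) = W_out ReLU(W_in z)
    return matvec(w_out, relu(matvec(w_in, z)))

def rand_matrix(rows, cols, rng):
    return [[rng.gauss(0.0, 1.0) for _ in range(cols)] for _ in range(rows)]

rng = random.Random(0)
d, h = 8, 16                       # model width d, hidden width h per branch
w1_t, w2_t = rand_matrix(h, d, rng), rand_matrix(d, h, rng)        # branch F_t
w1_next, w2_next = rand_matrix(h, d, rng), rand_matrix(d, h, rng)  # branch F_{t+1}
z = [rng.gauss(0.0, 1.0) for _ in range(d)]  # plays the role of Norm(x_t)

# Two stacked branches applied to the same input and summed ...
two_layers = [a + b for a, b in zip(mlp(w1_t, w2_t, z), mlp(w1_next, w2_next, z))]

# ... equal one branch with a 2h-wide hidden layer:
# stack the input weights vertically and the output weights horizontally.
w1_wide = w1_t + w1_next                                  # (2h x d)
w2_wide = [r_t + r_n for r_t, r_n in zip(w2_t, w2_next)]  # (d x 2h)
one_wide_layer = mlp(w1_wide, w2_wide, z)

max_err = max(abs(a - b) for a, b in zip(two_layers, one_wide_layer))
print(f"max |difference| = {max_err:.2e}")
```

The only difference between the two computations is the order of floating-point additions, so the printed error should be at the level of rounding noise.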
Summary
This article mainly shares an intuitive understanding of "why the performance of Pre-Norm is not as good as Post-Norm."
Original Address: https://kexue.fm/archives/9009