English (unofficial) translations of posts at kexue.fm

MoE Tour: 2. Not Worried about Scarcity, but about Inequality

Translated by Gemini Flash 3.0 Preview. Translations may be inaccurate; please refer to the original post for anything important.

In the previous article "MoE Tour: 1. Starting from Geometric Meaning", we introduced a geometric interpretation of MoE, aiming to derive and understand MoE starting from the best approximation of a Dense model. At the end of that article, we also mentioned that providing the calculation formula for MoE is only the beginning. Training a practically effective MoE model requires many details to be addressed, such as the Load Balance problem discussed in this article.

Load balance, or "not being worried about scarcity but about inequality," simply means making sure every Expert is working and that they are all doing as much work as possible, avoiding the waste of computational power by certain Experts. Load balancing is both a requirement for fully utilizing training computational power and a necessity for maximizing the potential of the large parameter count in MoE.

Demand Analysis

We know that the basic form of MoE is: \begin{equation} \boldsymbol{y} = \sum_{i\in \mathop{\text{argtop}}_k \boldsymbol{\rho}} \rho_i \boldsymbol{e}_i \end{equation} For traditional MoE, \boldsymbol{\rho} is a probability distribution (the Router), \boldsymbol{e}_i=\boldsymbol{v}_i, and \boldsymbol{v}_i is the output of a small FFN (the Expert). For the geometric MoE derived in the previous article, \boldsymbol{\rho} has no normalization requirement: it predicts the magnitude of each Expert's output, while \boldsymbol{e}_i=\boldsymbol{v}_i/\Vert\boldsymbol{v}_i\Vert supplies its direction.
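As a concrete illustration, here is a minimal pure-Python sketch of the basic MoE form above (not from the original post; the function names and toy experts are mine): score all Experts, keep the top-k, and sum their weighted outputs.

```python
def argtopk(scores, k):
    """Indices of the k largest router scores (argtop_k of rho)."""
    return sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:k]

def moe_forward(x, rho, experts, k=2):
    """y = sum_{i in argtop_k rho} rho_i * e_i(x)."""
    y = [0.0] * len(x)
    for i in argtopk(rho, k):
        e = experts[i](x)          # output of the i-th Expert
        for d in range(len(x)):
            y[d] += rho[i] * e[d]  # weight the Expert output by its router score
    return y
```

Here each expert is just a callable; in a real model it would be a small FFN and `rho` the Router's scores for the current token.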

Regardless of the form, MoE behaves similarly in practice; only the interpretive perspective differs. However, note that although the MoE formula gives the impression that "whenever a Token arrives, the corresponding Experts are found and computed," actual training works in reverse: computational resources are first allocated to each Expert, and Tokens are then distributed (routed) to their respective Experts for parallel computation. This is why \boldsymbol{\rho}, which does the scoring, is called the Router.

Consequently, if the distribution of Experts is unbalanced, the following situations may occur: some Experts (Dead Experts) remain idle almost all the time, wasting computational power; other Experts have too many Tokens to process and cannot keep up, leading to Token Drop (i.e., giving up on processing some Tokens). Theoretically, the appearance of Dead Experts means that the MoE has not reached its expected parameter capacity—meaning you paid for the VRAM of a large parameter model but only trained the effect of a small parameter model.

Therefore, whether from the perspective of training or performance, we hope to ensure the load balance of Experts.

Auxiliary Loss

The conventional approach to promoting load balance is to add a related loss function, which we usually call "Aux Loss (Auxiliary Loss)." The most mainstream Aux Loss currently in use can be traced back to the 2020 paper "GShard: Scaling Giant Models with Conditional Computation and Automatic Sharding".

Before introducing Aux Loss, we need to introduce some new concepts. First, as mentioned, for a general MoE, \boldsymbol{\rho} may not necessarily be a probability distribution. We denote the normalized \boldsymbol{\rho} as \boldsymbol{p}=[p_1,p_2,\cdots,p_n], and its Top-k version as \boldsymbol{f}=[f_1,f_2,\cdots,f_n], where: \begin{equation} p_i = \frac{\rho_i}{\sum_{j=1}^n \rho_j},\qquad f_i = \left\{\begin{aligned}1/k, \quad i\in \mathop{\text{argtop}}\nolimits_k \boldsymbol{\rho} \\ 0, \quad i\not\in \mathop{\text{argtop}}\nolimits_k \boldsymbol{\rho}\end{aligned}\right. \end{equation} Next, we define \boldsymbol{P}=\mathbb{E}[\boldsymbol{p}] and \boldsymbol{F}=\mathbb{E}[\boldsymbol{f}], where \mathbb{E} refers to the average over all Tokens of all samples. It is easy to see that \boldsymbol{F} is the current load distribution of the Experts, while \boldsymbol{P} is a smooth approximation of it.

With these notations, we can write the Aux Loss as: \begin{equation} \mathcal{L}_{\text{aux}} = \boldsymbol{F}\cdot \boldsymbol{P} = \sum_{i=1}^n F_i P_i \label{eq:aux-loss} \end{equation} General literature defining Aux Loss might multiply it by n, i.e., their Aux Loss equals n \mathcal{L}_{\text{aux}} here. Additionally, some large-scale MoEs calculate Aux Loss per device to achieve balance within each device and reduce inter-device communication; these are variations on the same idea. However, some recent experiments suggest that forcing such local balance may hurt the model's final performance.
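To make the notation concrete, here is a small pure-Python sketch (illustrative only; the function names are mine) that computes \boldsymbol{p} and \boldsymbol{f} per token, their batch averages \boldsymbol{P} and \boldsymbol{F}, and the Aux Loss \boldsymbol{F}\cdot\boldsymbol{P}, assuming the router scores are positive:

```python
def p_and_f(rho, k):
    """Normalized scores p and the scaled top-k indicator f for one token."""
    s = sum(rho)                     # assumes positive router scores
    p = [r / s for r in rho]
    top = sorted(range(len(rho)), key=lambda i: rho[i], reverse=True)[:k]
    f = [1.0 / k if i in top else 0.0 for i in range(len(rho))]
    return p, f

def aux_loss(batch_rho, k):
    """F = E[f] (load distribution), P = E[p] (smooth proxy), loss = F . P."""
    n, m = len(batch_rho[0]), len(batch_rho)
    P = [0.0] * n
    F = [0.0] * n
    for rho in batch_rho:            # average over all tokens in the batch
        p, f = p_and_f(rho, k)
        for i in range(n):
            P[i] += p[i] / m
            F[i] += f[i] / m
    return sum(Fi * Pi for Fi, Pi in zip(F, P)), P, F
```

Note that \boldsymbol{F}, built from \mathop{\text{argtop}}_k, is non-differentiable; only \boldsymbol{P} carries a usable gradient, which is the point of the derivation that follows.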

Straight-Through Estimator

I wonder if anyone has noticed a strange phenomenon: whether in the original source, subsequent literature, or popular science articles, the citation of Aux Loss is always given without proof, as if everyone agrees that the fact that the above Aux Loss can promote balance is self-evident. But is it really that obvious?

In any case, I couldn’t see it immediately. Therefore, I will provide a derivation logic for Equation [eq:aux-loss]. From this logic, we can also customize other forms of Aux Loss. First, define the uniform distribution \boldsymbol{Q}=(1/n,1/n,\cdots,1/n). As we just said, \boldsymbol{F} is the current load distribution; therefore, load balance is equivalent to \boldsymbol{F}=\boldsymbol{Q}. Thus, the following is a more intuitive Aux Loss: \begin{equation} \mathcal{L}_{\text{aux}} = \frac{1}{2}\Vert\boldsymbol{F} - \boldsymbol{Q}\Vert^2 = \frac{1}{2}\sum_{i=1}^n (F_i - 1/n)^2 \label{eq:aux-loss-2} \end{equation}

The problem is that \boldsymbol{F} is derived from \mathop{\text{argtop}}_k, which means the above equation is not a directly usable differentiable objective. How do we solve this? The answer is the STE (Straight-Through Estimator) trick, where we design different functions for forward and backward propagation. Specifically, while \boldsymbol{F} is non-differentiable, \boldsymbol{P} as its smooth approximation is differentiable. So, we can replace \boldsymbol{F} with \boldsymbol{P} during backpropagation, i.e.: \begin{equation} \mathcal{L}_{\text{aux}} = \frac{1}{2}\Vert \boldsymbol{P} + \text{sg}[\boldsymbol{F}-\boldsymbol{P}] - \boldsymbol{Q}\Vert^2 = \frac{1}{2}\sum_{i=1}^n (P_i + \text{sg}[F_i - P_i] - 1/n)^2 \label{eq:aux-loss-3} \end{equation} where \text{sg}[] is the stop gradient operator, which keeps the forward output unchanged but forces the gradient to be zero. After this modification, \mathcal{L}_{\text{aux}} becomes a feasible Aux Loss.
Let’s try to find its gradient: \begin{equation} \begin{aligned} \nabla_{\boldsymbol{\theta}}\mathcal{L}_{\text{aux}} =&\, \frac{1}{2}\nabla_{\boldsymbol{\theta}}\sum_{i=1}^n (P_i + \text{sg}[F_i - P_i] - 1/n)^2 \\ =&\, \sum_{i=1}^n (P_i + \text{sg}[F_i - P_i] - 1/n) \nabla_{\boldsymbol{\theta}}(P_i + \text{sg}[F_i - P_i] - 1/n)\\ =&\, \sum_{i=1}^n (F_i - 1/n) \nabla_{\boldsymbol{\theta}}P_i = \nabla_{\boldsymbol{\theta}}\sum_{i=1}^n (F_i - 1/n) P_i\\ =&\, \nabla_{\boldsymbol{\theta}}\left(\sum_{i=1}^n F_i P_i\right) \end{aligned} \end{equation} Here \boldsymbol{\theta} represents the model parameters. The final result shows that the gradient of Equation [eq:aux-loss-3] is equal to the gradient of Equation [eq:aux-loss]. This means using Equation [eq:aux-loss] as the Aux Loss is equivalent to Equation [eq:aux-loss-3] in terms of gradients, which is why the Aux Loss in Equation [eq:aux-loss] appeared.
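This equivalence can also be checked numerically. Treating \text{sg}[\cdot] as a constant, the gradient of Equation [eq:aux-loss-3] with respect to P_i is F_i - 1/n, while that of \boldsymbol{F}\cdot\boldsymbol{P} is F_i; any admissible perturbation of \boldsymbol{P} sums to zero (it must remain a probability distribution), so the constant 1/n drops out and the two directional derivatives coincide. A small sketch (mine, not from the post):

```python
def grad_ste(F, n):
    # d/dP_i of 0.5 * sum_j (P_j + sg[F_j - P_j] - 1/n)^2, with sg[.] held constant
    return [Fi - 1.0 / n for Fi in F]

def grad_fp(F):
    # d/dP_i of sum_j F_j P_j (F is piecewise constant in the parameters)
    return list(F)

def directional(grad, d):
    # directional derivative along a perturbation d of P
    return sum(g * di for g, di in zip(grad, d))
```

Along any direction d with sum(d) = 0, `directional(grad_ste(F, n), d)` equals `directional(grad_fp(F), d)`, which is exactly the gradient equality derived above.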

However, Equation [eq:aux-loss] is only meaningful through its gradient; it is not a "true" Loss in itself. For example, at perfect balance \boldsymbol{F} = \boldsymbol{P} = \boldsymbol{Q}, we can calculate that Equation [eq:aux-loss] equals 1/n, but we can construct an \boldsymbol{F} not equal to \boldsymbol{P} that makes it smaller than 1/n (e.g., by concentrating \boldsymbol{F} on the Experts with the smallest P_i). Therefore, Equation [eq:aux-loss] is not like a normal Loss that is better when smaller, nor is its minimum necessarily reached at balance.

General Form

The above derivation actually provides a general recipe for constructing Aux Loss: first, construct a loss that meets the requirements based on \boldsymbol{F}, and then replace \boldsymbol{F} with \boldsymbol{P} + \text{sg}[\boldsymbol{F}-\boldsymbol{P}] in the implementation. For example, we know that maximum entropy can also push a distribution toward balance; therefore, we can also use the negative entropy to construct an Aux Loss: \begin{equation} \mathcal{L}_{\text{aux}} = \sum_{i=1}^n (P_i + \text{sg}[F_i - P_i])\log(P_i + \text{sg}[F_i - P_i]) \end{equation} The above equation can be used directly in code. Of course, if we seek simplification, we can likewise compute the gradient, and the result is: \begin{equation} \nabla_{\boldsymbol{\theta}}\mathcal{L}_{\text{aux}} = \nabla_{\boldsymbol{\theta}}\sum_{i=1}^n(P_i + \text{sg}[F_i - P_i]) \log(P_i + \text{sg}[F_i - P_i]) = \nabla_{\boldsymbol{\theta}}\sum_{i=1}^n P_i \log F_i \end{equation} In both of these gradient simplifications (this one and the earlier quadratic case), we used the following identity: \begin{equation} \sum_{i=1}^n \nabla_{\boldsymbol{\theta}}P_i = \nabla_{\boldsymbol{\theta}}\sum_{i=1}^n P_i = \nabla_{\boldsymbol{\theta}}1 = \boldsymbol{0} \end{equation} This relies on the facts that \boldsymbol{P} is a probability distribution and that the target distribution \boldsymbol{Q} is uniform. If we do not seek the simplified equivalent result but directly use the Aux Loss in the form \boldsymbol{F}\to \boldsymbol{P} + \text{sg}[\boldsymbol{F}-\boldsymbol{P}], then we are not subject to these two constraints.
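A quick numerical sanity check of this entropy variant (a sketch under my own naming, not from the post): the forward value of the STE form is just \sum_i F_i \log F_i, and, with \text{sg}[\cdot] held constant, its gradient with respect to P_i is \log F_i + 1, which agrees with the gradient of \sum_i P_i \log F_i along any direction that keeps \boldsymbol{P} normalized:

```python
import math

def entropy_aux_value(F, eps=1e-12):
    # forward value of the STE form: sum_i F_i log F_i (negative entropy of the load);
    # eps guards against log(0) for completely idle Experts
    return sum(Fi * math.log(Fi + eps) for Fi in F)

def grad_entropy_ste(F):
    # d/dP_i of sum_j (P_j + sg[F_j - P_j]) log(P_j + sg[F_j - P_j]), sg[.] constant
    return [math.log(Fi) + 1.0 for Fi in F]

def grad_simplified(F):
    # d/dP_i of sum_j P_j log F_j (the simplified equivalent form)
    return [math.log(Fi) for Fi in F]
```

The two gradients differ by the constant 1 per component, which contributes nothing along directions summing to zero; and the value is minimized at the uniform load, consistent with negative entropy pushing \boldsymbol{F} toward balance.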

For instance, regarding \boldsymbol{P} as a smooth approximation of \boldsymbol{F}, we only used the property that "when P_i is large, F_i is usually also large." Therefore, using the non-normalized \mathbb{E}[\boldsymbol{\rho}] as \boldsymbol{P} is usually fine. This point might be critical in some special scenarios (such as when \boldsymbol{\rho} has both positive and negative values) because normalization into a probability distribution is impossible in such cases. Another example is the objective \Vert\boldsymbol{F} - \boldsymbol{Q}\Vert^2, which can clearly push \boldsymbol{F} toward any target distribution \boldsymbol{Q} we want, not necessarily a uniform one.

Summary

This article introduced the load balance problem in MoE and provided a general idea for constructing Aux Loss. Besides Aux Loss, there are other solutions to promote load balance, which we will discuss next time.

Original address: https://kexue.fm/archives/10735

For more details on reprinting, please refer to: "Scientific Space FAQ"