English (unofficial) translations of posts at kexue.fm

Softmax Sequel: Finding Smooth Approximations for Top-K

Translated by Gemini Flash 3.0 Preview. Translations can be inaccurate; please refer to the original post for important details.

Softmax, as the name suggests, is a "soft" version of the max operator. Specifically, it is a smooth approximation of the \text{argmax} operator. It transforms any vector \boldsymbol{x}\in\mathbb{R}^n into a new vector with non-negative components that sum to 1 through exponential normalization. It also allows us to adjust the degree of approximation to the one-hot form of \text{argmax} via a temperature parameter. In addition to exponential normalization, we previously introduced other schemes that achieve similar effects in "The Road to Probability Distributions: A Survey of Softmax and its Alternatives".

We know that the maximum value is often referred to as Top-1, and its smooth approximation schemes appear to be quite mature. However, have you ever wondered what a smooth approximation for a general Top-k operator looks like? Let us explore this problem together.

Problem Description

Let \boldsymbol{x}=(x_1,x_2,\cdots,x_n)\in\mathbb{R}^n. For simplicity, we assume that the components are pairwise distinct, i.e., i\neq j \Leftrightarrow x_i\neq x_j. Let \Omega_k(\boldsymbol{x}) be the set of indices of the k largest components of \boldsymbol{x}, such that |\Omega_k(\boldsymbol{x})|=k and \forall i\in \Omega_k(\boldsymbol{x}), j \not\in \Omega_k(\boldsymbol{x})\Rightarrow x_i > x_j. We define the Top-k operator \mathcal{T}_k as a mapping from \mathbb{R}^n\mapsto\{0,1\}^n: \begin{equation} [\mathcal{T}_k(\boldsymbol{x})]_i = \left\{\begin{aligned}1,\,\, i\in \Omega_k(\boldsymbol{x}) \\ 0,\,\, i \not\in \Omega_k(\boldsymbol{x})\end{aligned}\right. \end{equation} In simple terms, if x_i is among the k largest elements, the corresponding position becomes 1; otherwise, it becomes 0. The final result is a multi-hot vector. For example, \mathcal{T}_2([3,2,1,4]) = [1,0,0,1].
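For concreteness, here is a minimal NumPy sketch of the hard operator \mathcal{T}_k. It assumes distinct components as above, and the helper name top_k_mask is my own, not from the original post. It reproduces the example \mathcal{T}_2([3,2,1,4]) = [1,0,0,1].

```python
import numpy as np

def top_k_mask(x, k):
    """Hard Top-k operator T_k: 1 at the indices of the k largest entries, 0 elsewhere."""
    x = np.asarray(x, dtype=float)
    mask = np.zeros_like(x)
    # argpartition with kth=-k puts the indices of the k largest entries in the last k slots
    mask[np.argpartition(x, -k)[-k:]] = 1.0
    return mask

print(top_k_mask([3, 2, 1, 4], 2))  # [1. 0. 0. 1.]
```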

The transition from \boldsymbol{x} to \mathcal{T}_k(\boldsymbol{x}) is essentially a "hard assignment" operation. It is discontinuous, and because it is piecewise constant, its gradient with respect to \boldsymbol{x} is zero wherever it exists, so it carries no useful training signal and cannot be integrated into a model for end-to-end training. To solve this, we need to construct a smooth approximation of \mathcal{T}_k(\boldsymbol{x}) that provides effective gradient information, often referred to in the literature as a "Differentiable Top-k Operator."

Specifically, we define the set: \begin{equation} \Delta_k^{n-1} = \left\{\boldsymbol{p}=(p_1,p_2,\cdots,p_n)\left|\, p_1,p_2,\cdots,p_n\in[0,1],\sum_{i=1}^n p_i = k\right.\right\} \end{equation} Our goal is to construct a mapping \mathcal{ST}_k(\boldsymbol{x}): \mathbb{R}^n\mapsto \Delta_k^{n-1} that satisfies the following properties as much as possible: \begin{align} &\text{\textcolor{red}{Monotonicity}}:\quad [\mathcal{ST}_k(\boldsymbol{x})]_i \geq [\mathcal{ST}_k(\boldsymbol{x})]_j \,\,\Leftrightarrow\,\, x_i \geq x_j \\[8pt] &\text{\textcolor{red}{Invariance}}:\quad \mathcal{ST}_k(\boldsymbol{x}) = \mathcal{ST}_k(\boldsymbol{x} + c),\,\,\forall c\in\mathbb{R} \\[8pt] &\text{\textcolor{red}{Convergence}}:\quad \lim_{\tau\to 0^+}\mathcal{ST}_k(\boldsymbol{x}/\tau) = \mathcal{T}_k(\boldsymbol{x}) \end{align} It can be verified that Softmax, as \mathcal{ST}_1(\boldsymbol{x}), satisfies these properties. Thus, proposing these properties is essentially an attempt to make the constructed \mathcal{ST}_k(\boldsymbol{x}) a natural generalization of Softmax. Of course, constructing a smooth approximation for Top-k is inherently more difficult than for Top-1, so if difficulties arise, we do not strictly need to follow all properties, as long as the mapping exhibits the characteristics of a smooth approximation of \mathcal{T}_k(\boldsymbol{x}).
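As a quick sanity check of the claim that Softmax satisfies these properties as \mathcal{ST}_1, the sketch below tests membership in \Delta_1^{n-1}, Invariance, Monotonicity and (approximately) Convergence on one example vector. The vector and the temperature 10^{-3} are arbitrary choices of mine.

```python
import numpy as np

def softmax(x):
    # Subtracting the max is itself an instance of Invariance and avoids overflow
    z = np.exp(x - x.max())
    return z / z.sum()

x = np.array([3.0, 2.0, 1.0, 4.0])
p = softmax(x)
print(np.isclose(p.sum(), 1.0))                      # True: p lies in Delta_1^{n-1}
print(np.allclose(softmax(x + 7.0), p))              # True: Invariance
print(np.array_equal(np.argsort(p), np.argsort(x)))  # True: Monotonicity (same ordering)
print(softmax(x / 1e-3))                             # [0. 0. 0. 1.], i.e. T_1(x)
```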

Iterative Construction

In fact, I have been interested in this problem for a long time. It was first discussed in a 2019 article "Random Talk on Function Smoothing: Differentiable Approximations of Non-differentiable Functions", where it was called \text{soft-}k\text{-max}, and an iterative construction scheme was provided:

Input \boldsymbol{x}, initialize \boldsymbol{p}^{(0)} as an all-zero vector;
Execute \boldsymbol{x} = \boldsymbol{x} - \min(\boldsymbol{x}) (to ensure all elements are non-negative).

For i=1,2,\dots,k, execute:
\boldsymbol{y} = (1 - \boldsymbol{p}^{(i-1)})\otimes\boldsymbol{x};
\boldsymbol{p}^{(i)} = \boldsymbol{p}^{(i-1)} + \text{softmax}(\boldsymbol{y})

Return \boldsymbol{p}^{(k)}.

The logic behind this iteration is simple. To understand it, first replace \text{softmax}(\boldsymbol{y}) with \mathcal{T}_1(\boldsymbol{y}). The process then makes \boldsymbol{x} non-negative, identifies the Top-1, sets it to zero (so the former maximum becomes a minimum), identifies the next Top-1, and so on, so the final \boldsymbol{p}^{(k)} is exactly \mathcal{T}_k(\boldsymbol{x}). Since \text{softmax}(\boldsymbol{y}) is a smooth approximation of \mathcal{T}_1(\boldsymbol{y}), using it in the iteration naturally yields a smooth approximation of \mathcal{T}_k(\boldsymbol{x}).
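The iteration translates almost line by line into NumPy. This is a sketch under my own naming (soft_k_max, softmax); it uses a max-shifted Softmax for numerical stability, which does not change the result.

```python
import numpy as np

def softmax(y):
    z = np.exp(y - y.max())  # max shift for numerical stability
    return z / z.sum()

def soft_k_max(x, k):
    """Iterative soft-k-max (a sketch of the steps listed above)."""
    x = np.asarray(x, dtype=float)
    x = x - x.min()              # make every component non-negative
    p = np.zeros_like(x)         # p^{(0)} = 0
    for _ in range(k):
        y = (1.0 - p) * x        # shrink components that are already selected
        p = p + softmax(y)       # p^{(i)} = p^{(i-1)} + softmax(y)
    return p

print(soft_k_max(np.array([3.0, 2.0, 1.0, 4.0]) / 0.01, 2))  # close to [1, 0, 0, 1]
```

Each step adds a probability vector, so the output always sums to k; whether every component also stays within [0,1] is the harder question.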

Coincidentally, I found a similar idea in a response to the Stack Exchange question "Is there something like softmax but for top k values?". The respondent proposed a weighted Softmax: \begin{equation} [\text{softmax}(\boldsymbol{x};\boldsymbol{w})]_i = \frac{w_i e^{x_i}}{\sum\limits_{j=1}^n w_j e^{x_j}} \end{equation} and the following iterative process:

Input \boldsymbol{x}, initialize \boldsymbol{p}^{(0)} as an all-zero vector;

For i=1,2,\dots,k, execute:
\boldsymbol{p}^{(i)} = \boldsymbol{p}^{(i-1)} + \text{softmax}(\boldsymbol{x}; 1 - \boldsymbol{p}^{(i-1)})

Return \boldsymbol{p}^{(k)}.

This is conceptually identical to my iterative process, except that I multiplied 1 - \boldsymbol{p}^{(i-1)} by \boldsymbol{x}, while they multiplied it by e^{\boldsymbol{x}}, which simplifies the process by exploiting the non-negativity of e^{\boldsymbol{x}}. However, this iteration is actually incorrect because it does not satisfy Convergence. For instance, when k=2 and we take the limit \tau\to 0^+ of \boldsymbol{x}/\tau, the result is not a multi-hot vector: the component at the maximum tends to 1.5, the second largest tends to 0.5, and the rest tend to 0. The reason is that after the first step 1-p_{\max}\approx\sum_{j\neq\max} e^{x_j - x_{\max}}, so the weight (1-p_{\max})e^{x_{\max}}\approx\sum_{j\neq\max} e^{x_j} is of the same order as the weight e^{x_j} of the second-largest component. The maximum is therefore not eliminated, and it takes half of the second unit of mass.
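The failure of the weighted iteration can be reproduced numerically. In this sketch (names are mine) I use \tau=0.05 rather than a much smaller value: for very small \tau, 1-p_{\max} rounds to exactly 0 in float64, which would hide the effect behind a rounding artifact.

```python
import numpy as np

def weighted_softmax(x, w):
    # [softmax(x; w)]_i = w_i e^{x_i} / sum_j w_j e^{x_j}, shifted by max(x) for stability
    z = w * np.exp(x - x.max())
    return z / z.sum()

def weighted_iteration(x, k):
    p = np.zeros_like(x)
    for _ in range(k):
        p = p + weighted_softmax(x, 1.0 - p)  # weights 1 - p^{(i-1)}
    return p

x = np.array([3.0, 2.0, 1.0, 0.0])
print(np.round(weighted_iteration(x / 0.05, 2), 3))  # approximately 1.5, 0.5, 0, 0
```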

As a Gradient

Iterative constructions are based on heuristics and may hide subtle issues; for example, the seemingly simpler weighted Softmax iteration turns out to be invalid. Without a more fundamental guiding principle, such schemes are also difficult to analyze theoretically: although my iterative scheme works well in tests, it is hard to prove that the components of \boldsymbol{p}^{(k)} lie within [0,1] or that it satisfies Monotonicity.

Therefore, we seek a higher-level principle to guide the design of this smooth approximation. A few days ago, I realized a key fact: \begin{equation} \mathcal{T}_k(\boldsymbol{x}) = \nabla_{\boldsymbol{x}} \sum_{i\in\Omega_k(\boldsymbol{x})} x_i \end{equation} In other words, the gradient of the sum of the k largest components is exactly \mathcal{T}_k(\boldsymbol{x}) (the gradient exists because the components are assumed distinct). Thus, we can switch to finding a smooth approximation for \sum_{i\in\Omega_k(\boldsymbol{x})} x_i and then take the gradient. This sum is a scalar, which is easier to approximate. For example, we can use the identity: \begin{equation} \sum_{i\in\Omega_k(\boldsymbol{x})} x_i = \max_{i_1 < \cdots < i_k} (x_{i_1} + \cdots + x_{i_k}) \end{equation} That is, we enumerate all sums of k distinct components and take the maximum. Now the problem becomes finding a smooth approximation for \max, which we have already solved (see "Seeking a Smooth Maximum Function"). The answer is \text{logsumexp}: \begin{equation} \max_{i_1 < \cdots < i_k} (x_{i_1} + \cdots + x_{i_k})\approx \log\sum_{i_1 < \cdots < i_k} e^{x_{i_1} + \cdots + x_{i_k}}\triangleq \log Z_k \end{equation} Taking the gradient, we obtain a form for \mathcal{ST}_k(\boldsymbol{x}): \begin{equation} [\mathcal{ST}_k(\boldsymbol{x})]_i = \frac{\sum\limits_{i_2 < \cdots < i_k,\ i\notin\{i_2,\cdots,i_k\}} e^{x_i+x_{i_2} + \cdots + x_{i_k}}}{\sum\limits_{i_1 < \cdots < i_k} e^{x_{i_1} +x_{i_2}+ \cdots + x_{i_k}}}\triangleq \frac{Z_{k,i}}{Z_k}\label{eq:k-max-grad} \end{equation} The denominator is the sum of exponentials of all k-component sums, and the numerator is the sum of exponentials of all k-component sums that include x_i. From this form, we can easily prove (for k<n): \begin{equation} 0 < [\mathcal{ST}_k(\boldsymbol{x})]_i < 1,\quad \sum_{i=1}^n [\mathcal{ST}_k(\boldsymbol{x})]_i = k \end{equation} Thus, \mathcal{ST}_k(\boldsymbol{x}) defined this way indeed belongs to \Delta_k^{n-1}. In fact, we can also prove it satisfies Monotonicity, Invariance, and Convergence.
Furthermore, \mathcal{ST}_1(\boldsymbol{x}) is exactly Softmax. These properties show it is a natural generalization of Softmax for the Top-k operator. We shall call it "GradTopK (Gradient-guided Soft Top-k operator)."
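For small n, Equation [eq:k-max-grad] can be evaluated directly by enumerating all C_n^k index subsets, which is handy as a reference when testing faster methods. This is a brute-force sketch of my own (the name grad_topk_bruteforce is hypothetical); subtracting the largest subset sum (the Top-k sum) before exponentiating is an added stabilization that cancels in the ratio.

```python
import itertools
import numpy as np

def grad_topk_bruteforce(x, k):
    """[ST_k(x)]_i = Z_{k,i} / Z_k, computed by enumerating all k-subsets (small n only)."""
    x = np.asarray(x, dtype=float)
    n = len(x)
    shift = np.sort(x)[-k:].sum()        # largest subset sum, so every exponent is <= 0
    Z = 0.0
    Zi = np.zeros(n)
    for idx in itertools.combinations(range(n), k):
        w = np.exp(x[list(idx)].sum() - shift)
        Z += w                           # denominator: every k-subset
        Zi[list(idx)] += w               # numerator: only the subsets that contain i
    return Zi / Z

x = np.array([0.3, -1.2, 2.0, 0.7, -0.4])
print(grad_topk_bruteforce(x, 3).sum())  # 3 up to rounding
```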

However, it is not yet time to celebrate, as the numerical calculation of Equation [eq:k-max-grad] remains unresolved. Calculating it directly involves C_n^k terms in the denominator, which is computationally expensive, so we need an efficient method. With the numerator and denominator denoted Z_{k,i} and Z_k, observe that Z_{k,i} satisfies the recurrence below, because a k-subset containing i consists of i together with a (k-1)-subset that does not contain i: \begin{equation} Z_{k,i} = e^{x_i}(Z_{k-1} - Z_{k-1,i}) \end{equation} Combined with the fact that \sum_{i=1}^n Z_{k,i} = kZ_k (each k-subset is counted once for each of its k members), we obtain a recursive calculation, written in log form to reduce the risk of overflow: \begin{equation} \begin{aligned} \log Z_{k,i} =&\, x_i + \log(e^{\log Z_{k-1}} - e^{\log Z_{k-1,i}}) \\ \log Z_k =&\, \left(\log\sum_{i=1}^n e^{\log Z_{k,i}}\right) - \log k \\ \end{aligned} \end{equation} where \log Z_{1,i} = x_i. Now, calculating \mathcal{ST}_k(\boldsymbol{x}) only requires k iterations, which is efficient enough. However, even in log space this recursion only works for \boldsymbol{x} with small variance or small k. Otherwise, \log Z_{k-1} and the largest \log Z_{k-1,i} become so close that floating-point precision is exhausted, and the subtraction produces \log 0. I believe this is a fundamental difficulty of this recursive formulation.
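The recurrence and the identity \sum_i Z_{k,i} = kZ_k can be checked in plain (non-log) space on a tiny input. The sketch below uses my own test vector, runs the recursion up to k=n, and compares each step with brute-force enumeration.

```python
import itertools
import numpy as np

x = np.array([0.5, -0.3, 1.2, 0.1, -0.8])  # arbitrary small test vector
n, e = len(x), np.exp(x)

def Z_direct(k):
    """Return Z_k and the vector (Z_{k,i})_i by enumerating all k-subsets."""
    Z, Zi = 0.0, np.zeros(n)
    for idx in itertools.combinations(range(n), k):
        w = np.prod(e[list(idx)])
        Z += w
        Zi[list(idx)] += w
    return Z, Zi

# k = 1: Z_{1,i} = e^{x_i} and Z_1 = sum_i e^{x_i}
Zi = e.copy()
Z = Zi.sum()
for k in range(2, n + 1):
    Zi = e * (Z - Zi)   # Z_{k,i} = e^{x_i} (Z_{k-1} - Z_{k-1,i})
    Z = Zi.sum() / k    # sum_i Z_{k,i} = k Z_k
    Z_ref, Zi_ref = Z_direct(k)
    assert np.isclose(Z, Z_ref) and np.allclose(Zi, Zi_ref)
print("recurrence matches brute force up to k =", n)
```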

A simple reference implementation:

import numpy as np

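def GradTopK(x, k):
    # Log-space recursion: after iteration i, logZs[j] = log Z_{i,j} and logZ = log Z_i
    for i in range(1, k + 1):
        # log Z_{1,j} = x_j; for i > 1, log Z_{i,j} = x_j + log(Z_{i-1} - Z_{i-1,j})
        # (the log(1 - exp(...)) term is where log 0 appears when Z_{i-1,j} ~ Z_{i-1})
        logZs = x if i == 1 else x + logZ + np.log(1 - np.exp(logZs - logZ))
        # sum_j Z_{i,j} = i * Z_i
        logZ = np.logaddexp.reduce(logZs) - np.log(i)
    # [ST_k(x)]_j = Z_{k,j} / Z_k
    return np.exp(logZs - logZ)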

k, x = 10, np.random.randn(100)
GradTopK(x, k)

Undetermined Constant

The previous approach of using gradients to build a Top-k smooth approximation offers a high-level aesthetic, but some readers might find it too abstract. Furthermore, the numerical instability for large variance \boldsymbol{x} or large k is unsatisfying. Next, we explore a bottom-up construction.

Method Overview

This idea comes from a response to another Stack Exchange post "Differentiable top-k function". Let f(x) be any monotonically increasing function from \mathbb{R}\mapsto [0,1] (ideally smooth) such that \lim\limits_{x\to\infty}f(x) = 1 and \lim\limits_{x\to-\infty}f(x) = 0. Such functions are easy to construct: the classic Sigmoid \sigma(x)=1/(1+e^{-x}) is smooth, while \text{clip}(x,0,1) and \min(1, e^x) are only piecewise smooth. Consider: \begin{equation} f(\boldsymbol{x}) = [f(x_1),f(x_2),\cdots,f(x_n)] \end{equation} How far is f(\boldsymbol{x}) from our desired \mathcal{ST}_k(\boldsymbol{x})? Each component is already in [0,1], but the sum is not guaranteed to be k. Thus, we introduce an undetermined constant \lambda(\boldsymbol{x}) to ensure this: \begin{equation} \mathcal{ST}_k(\boldsymbol{x}) \triangleq f(\boldsymbol{x} - \lambda(\boldsymbol{x})),\quad \sum_{i=1}^n f(x_i - \lambda(\boldsymbol{x})) = k \end{equation} We solve for \lambda(\boldsymbol{x}) such that the sum of components is k. We can call this "ThreTopK (Threshold-adjusted Soft Top-k operator)." If you have read "The Road to Probability Distributions: A Survey of Softmax and its Alternatives", you will recognize that this approach follows the same idea as Sparsemax and Entmax-\alpha.

Is ThreTopK our ideal \mathcal{ST}_k(\boldsymbol{x})? Indeed! First, since f is monotonic, Monotonicity is satisfied. Second, f(\boldsymbol{x} - \lambda(\boldsymbol{x}))=f(\boldsymbol{x}+c - (c+\lambda(\boldsymbol{x}))), meaning the constant can be absorbed into \lambda(\boldsymbol{x}), so Invariance is satisfied. Finally, as \tau\to 0^+, we can find an appropriate threshold \lambda(\boldsymbol{x}/\tau) such that the k largest components of \boldsymbol{x}/\tau-\lambda(\boldsymbol{x}/\tau) tend to \infty and the rest tend to -\infty, so f(\boldsymbol{x}/\tau-\lambda(\boldsymbol{x}/\tau)) converges to \mathcal{T}_k(\boldsymbol{x}) and Convergence is satisfied.
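To make the construction concrete for an arbitrary f, here is a generic sketch that finds \lambda(\boldsymbol{x}) by bisection, since \sum_i f(x_i-\lambda) is non-increasing in \lambda. The bracket-widening loop and the helper names (thre_topk, clip01, exp_min) are my own assumptions, and it requires 0<k<n. Writing \sigma(x) as (1+\tanh(x/2))/2 is an exact identity that avoids overflow.

```python
import numpy as np

def thre_topk(x, k, f, tol=1e-12, max_iter=200):
    """Generic ThreTopK sketch: f(x - lam) with lam chosen so the entries sum to k.

    Assumes f is non-decreasing with limits 0 (at -inf) and 1 (at +inf), and 0 < k < len(x).
    """
    x = np.asarray(x, dtype=float)
    F = lambda lam: f(x - lam).sum()       # non-increasing in lam
    # Widen the bracket until F(low) >= k >= F(high)
    width = 1.0
    while F(x.min() - width) < k:
        width *= 2.0
    low = x.min() - width
    width = 1.0
    while F(x.max() + width) > k:
        width *= 2.0
    high = x.max() + width
    # Bisection: F(mid) < k means the root lies to the left of mid
    for _ in range(max_iter):
        if high - low < tol:
            break
        mid = 0.5 * (low + high)
        low, high = (low, mid) if F(mid) < k else (mid, high)
    return f(x - 0.5 * (low + high))

sigmoid = lambda z: 0.5 * (1.0 + np.tanh(0.5 * z))  # equals 1 / (1 + e^{-z})
clip01 = lambda z: np.clip(z, 0.0, 1.0)
exp_min = lambda z: np.exp(np.minimum(z, 0.0))      # equals min(1, e^z) without overflow

x = np.random.default_rng(0).normal(size=20)
print(thre_topk(x, 5, sigmoid).sum())  # 5 up to the bisection tolerance
```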

Analytical Solution

Having shown that ThreTopK satisfies all three properties, we now address the calculation of \lambda(\boldsymbol{x}). In most cases this requires numerical methods, but for f(x)=\min(1, e^x) an analytical solution exists.

The logic is similar to Sparsemax. Without loss of generality, assume the components of \boldsymbol{x} are sorted in descending order: x_1 > x_2 > \cdots > x_n. Suppose we know that x_m \geq \lambda(\boldsymbol{x}) \geq x_{m+1} (with the convention x_0 = +\infty). Then: \begin{equation} k = \sum_{i=1}^n \min(1, e^{x_i - \lambda(\boldsymbol{x})}) = m + \sum_{i=m+1}^n e^{x_i - \lambda(\boldsymbol{x})} \end{equation} Solving for \lambda(\boldsymbol{x}): \begin{equation} \lambda(\boldsymbol{x})=\log\left(\sum_{i=m+1}^n e^{x_i}\right) - \log(k-m) \end{equation} From this, we see that when k=1, m must be 0, and ThreTopK becomes Softmax. When k > 1, we cannot determine m beforehand. We do know that m \leq k-1: if m \geq k, the first k terms alone would already sum to k, and the remaining positive terms would push the total above k. So we iterate through m=0,1,\cdots,k-1 and keep the \lambda(\boldsymbol{x}) that satisfies x_m \geq \lambda(\boldsymbol{x}) \geq x_{m+1}.
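The case analysis can also be written as an explicit loop over m, which may be easier to follow than a vectorized version. This sketch (the name thre_topk_exp is mine) sorts \boldsymbol{x} in descending order, tries m=0,\dots,k-1, and returns the result for the first \lambda(\boldsymbol{x}) that satisfies x_m \geq \lambda(\boldsymbol{x}) \geq x_{m+1}.

```python
import numpy as np

def thre_topk_exp(x, k):
    """ThreTopK with f(x) = min(1, e^x), trying m = 0, 1, ..., k-1 explicitly."""
    x = np.asarray(x, dtype=float)
    xs = np.sort(x)[::-1]                                   # descending: xs[0] = x_1, xs[1] = x_2, ...
    for m in range(k):                                      # m = number of entries saturated at 1
        lamb = np.logaddexp.reduce(xs[m:]) - np.log(k - m)  # log sum_{i>m} e^{x_i} - log(k-m)
        upper = np.inf if m == 0 else xs[m - 1]             # x_m, with x_0 = +inf
        if upper >= lamb >= xs[m]:                          # x_m >= lambda >= x_{m+1}
            return np.exp(np.minimum(x - lamb, 0.0))        # = min(1, e^{x - lambda})
    raise ValueError("no consistent m found (tie or rounding at a boundary?)")

x = np.random.default_rng(1).normal(size=30)
print(thre_topk_exp(x, 5).sum())  # 5 up to rounding
```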

A simple reference implementation:

import numpy as np

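def ThreTopK(x, k):
    x_sort = np.sort(x)  # ascending
    # Candidates for m = k-1, ..., 0: log-sum of e^{x_i} over the n-m smallest entries, minus log(k-m)
    x_lamb = np.logaddexp.accumulate(x_sort)[-k:] - np.log(np.arange(k) + 1)
    # Upper bound x_m for each candidate (+inf when m = 0)
    x_sort_shift = np.pad(x_sort[-k:][1:], (0, 1), constant_values=np.inf)
    # Keep the candidate with x_m >= lambda >= x_{m+1}; take the first if a boundary tie yields two
    lamb = x_lamb[(x_lamb <= x_sort_shift) & (x_lamb >= x_sort[-k:])][0]
    # exp(min(., 0)) equals min(1, e^{x - lambda}) and avoids overflow warnings
    return np.exp(np.minimum(x - lamb, 0))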

k, x = 10, np.random.randn(100)
ThreTopK(x, k)

General Results

As seen from the theory and code, ThreTopK with f(x)=\min(1, e^x) has almost no numerical stability issues and reduces to Softmax when k=1. However, \min(1, e^x) is not differentiable at x=0, so the result is not perfectly smooth (except when k=1, where the \min never activates). If this is a concern, we need an everywhere-differentiable f(x), such as \sigma(x).

Taking f(x)=\sigma(x) as an example, we cannot find an analytical solution for \lambda(\boldsymbol{x}). However, since \sigma(x) is monotonically increasing, the function: \begin{equation} F(\lambda)\triangleq \sum_{i=1}^n \sigma(x_i - \lambda) \end{equation} is monotonically decreasing with respect to \lambda. Thus, solving F(\lambda(\boldsymbol{x}))=k numerically is not difficult using bisection or Newton’s method. For bisection, it is clear that \lambda(\boldsymbol{x})\in[x_{\min} - \sigma^{-1}(k/n), x_{\max} - \sigma^{-1}(k/n)], where \sigma^{-1} is the inverse of \sigma.

import numpy as np

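def sigmoid(x):
    # Stable sigmoid: exp is only evaluated at non-positive arguments
    y = np.exp(-np.abs(x))
    return np.where(x >= 0, 1, y) / (1 + y)

def sigmoid_inv(x):
    return np.log(x / (1 - x))

def ThreTopK(x, k, epsilon=1e-4):
    # lambda lies in [x_min - sigma^{-1}(k/n), x_max - sigma^{-1}(k/n)]
    low = x.min() - sigmoid_inv(k / len(x))
    high = x.max() - sigmoid_inv(k / len(x))
    while high - low > epsilon:
        lamb = (low + high) / 2
        Z = sigmoid(x - lamb).sum()
        # F(lambda) is decreasing: if F(lambda) < k, the root lies to the left
        low, high = (low, lamb) if Z < k else (lamb, high)
    lamb = (low + high) / 2  # midpoint of the final bracket (also defined if the loop never runs)
    return sigmoid(x - lamb)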

k, x = 10, np.random.randn(100)
ThreTopK(x, k)
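The text also mentions Newton's method. Plain Newton iterations can overshoot because \sigma is not convex, so this sketch (my own, not from the original post) keeps a bisection bracket and falls back to bisection whenever a Newton step would leave it. It uses F'(\lambda) = -\sum_i \sigma'(x_i-\lambda) with \sigma'=\sigma(1-\sigma), and the same initial bracket as the bisection code.

```python
import numpy as np

def sigmoid(z):
    return 0.5 * (1.0 + np.tanh(0.5 * z))   # equals 1 / (1 + e^{-z}) without overflow

def solve_lambda_newton(x, k, tol=1e-12, max_iter=100):
    """Solve sum_i sigmoid(x_i - lam) = k with Newton steps kept inside a bisection bracket."""
    s_inv = np.log(k / (len(x) - k))        # sigma^{-1}(k/n)
    low, high = x.min() - s_inv, x.max() - s_inv
    lam = 0.5 * (low + high)
    for _ in range(max_iter):
        s = sigmoid(x - lam)
        g = s.sum() - k                     # F(lam) - k, decreasing in lam
        if abs(g) < tol:
            break
        if g > 0:
            low = lam                       # F too large: root lies to the right
        else:
            high = lam                      # F too small: root lies to the left
        dF = -(s * (1.0 - s)).sum()         # F'(lam) = -sum_i sigma'(x_i - lam)
        step = lam - g / dF if dF < 0 else np.nan
        # Accept the Newton step only if it stays strictly inside the bracket
        lam = step if low < step < high else 0.5 * (low + high)
    return lam

x = np.random.default_rng(0).normal(size=100)
lam = solve_lambda_newton(x, 10)
print(sigmoid(x - lam).sum())  # close to 10
```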

Numerical calculation of \lambda(\boldsymbol{x}) is not the main difficulty; the real problem is that numerical methods often lose the gradient of \lambda(\boldsymbol{x}) with respect to \boldsymbol{x}, affecting end-to-end training. To fix this, we can manually calculate \nabla_{\boldsymbol{x}}\lambda(\boldsymbol{x}) and define a custom backpropagation. Differentiating: \begin{equation} \sum_{i=1}^n \sigma(x_i - \lambda(\boldsymbol{x})) = k \end{equation} with respect to x_j, we get: \begin{equation} \sigma'(x_j - \lambda(\boldsymbol{x}))-\sum_{i=1}^n \sigma'(x_i - \lambda(\boldsymbol{x}))\frac{\partial\lambda(\boldsymbol{x})}{\partial x_j} = 0 \end{equation} Thus: \begin{equation} \frac{\partial\lambda(\boldsymbol{x})}{\partial x_j} = \frac{\sigma'(x_j - \lambda(\boldsymbol{x}))}{\sum\limits_{i=1}^n \sigma'(x_i - \lambda(\boldsymbol{x}))} \end{equation} where \sigma' is the derivative of \sigma. With this expression, we can use the "stop gradient" (sg) trick to implement custom gradients: \begin{equation} \boldsymbol{x}\cdot\text{sg}[\nabla_{\boldsymbol{x}}\lambda(\boldsymbol{x})] + \text{sg}[\lambda(\boldsymbol{x}) - \boldsymbol{x}\cdot\nabla_{\boldsymbol{x}}\lambda(\boldsymbol{x})] \end{equation} This ensures the forward pass uses \lambda(\boldsymbol{x}) while the backward pass uses the specified \nabla_{\boldsymbol{x}}\lambda(\boldsymbol{x}).
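Before wiring the formula into a custom backward pass (in an autodiff framework the sg operator would typically be something like jax.lax.stop_gradient or PyTorch's Tensor.detach), it is worth checking it numerically. This NumPy sketch, with my own helper names, solves for \lambda(\boldsymbol{x}) by bisection and compares the closed-form gradient against central finite differences. The gradient also sums to 1, consistent with \lambda(\boldsymbol{x}+c)=\lambda(\boldsymbol{x})+c.

```python
import numpy as np

def sigmoid(z):
    return 0.5 * (1.0 + np.tanh(0.5 * z))

def solve_lambda(x, k, iters=200):
    """Bisection on the bracket from the text; F(lam) = sum_i sigmoid(x_i - lam) decreases in lam."""
    s_inv = np.log(k / (len(x) - k))
    low, high = x.min() - s_inv, x.max() - s_inv
    for _ in range(iters):
        mid = 0.5 * (low + high)
        low, high = (low, mid) if sigmoid(x - mid).sum() < k else (mid, high)
    return 0.5 * (low + high)

def grad_lambda(x, k):
    """d lam / d x_j = sigma'(x_j - lam) / sum_i sigma'(x_i - lam)."""
    s = sigmoid(x - solve_lambda(x, k))
    ds = s * (1.0 - s)                      # sigma'(z) = sigma(z) (1 - sigma(z))
    return ds / ds.sum()

# Compare with central finite differences
rng = np.random.default_rng(0)
x, k, eps = rng.normal(size=8), 3, 1e-6
fd = np.array([(solve_lambda(x + eps * e, k) - solve_lambda(x - eps * e, k)) / (2 * eps)
               for e in np.eye(len(x))])
print(np.max(np.abs(fd - grad_lambda(x, k))))  # small, limited by finite-difference error
```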

Both Worlds

We have seen that f(x)=\min(1,e^x) has an analytical solution but is not globally smooth, while f(x)=\sigma(x) is smooth but requires numerical solving. Is there a choice that combines both? I found that the following f(x) is continuously differentiable everywhere and allows for an analytical solution for \lambda(\boldsymbol{x}): \begin{equation} f(x) = \left\{\begin{aligned}1 - e^{-x}/2,\quad x\geq 0 \\ e^x / 2,\quad x < 0\end{aligned}\right. \end{equation} This can also be written as f(x) = (1 - e^{-|x|})\text{sign}(x)/2+1/2. It is an S-shaped function, and although it is piecewise, both the function and its derivative (e^{-|x|}/2 on either side) are continuous at x=0. It is therefore continuously differentiable, which is smooth enough for our purposes, even though its second derivative jumps at x=0.
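This f is the cumulative distribution function of the standard Laplace distribution, whose density e^{-|x|}/2 is exactly f'(x). The sketch below (helper names mine) checks numerically that the piecewise and the sign-based forms agree and that the one-sided slopes at 0 both equal 1/2.

```python
import numpy as np

def f_piecewise(z):
    e = np.exp(-np.abs(z)) / 2            # e^{-|z|}/2 avoids overflow in the unused branch
    return np.where(z >= 0, 1.0 - e, e)   # 1 - e^{-z}/2 for z >= 0, e^{z}/2 for z < 0

def f_sign_form(z):
    return (1.0 - np.exp(-np.abs(z))) * np.sign(z) / 2 + 0.5

z = np.linspace(-5, 5, 1001)
print(np.allclose(f_piecewise(z), f_sign_form(z)))  # True: the two forms agree
h = 1e-6
left = (f_piecewise(0.0) - f_piecewise(-h)) / h     # one-sided slopes at 0
right = (f_piecewise(h) - f_piecewise(0.0)) / h
print(left, right)                                  # both close to 0.5
```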

Assuming x_1 > x_2 > \cdots > x_n and x_m \geq \lambda(\boldsymbol{x}) \geq x_{m+1}: \begin{equation} \begin{aligned} k =&\, \sum_{i=1}^m (1 - e^{-(x_i - \lambda(\boldsymbol{x}))}/2) + \sum_{i=m+1}^n e^{x_i - \lambda(\boldsymbol{x})}/2 \\ =&\, m - \frac{1}{2}e^{\lambda(\boldsymbol{x})}\sum_{i=1}^m e^{-x_i} + \frac{1}{2}e^{-\lambda(\boldsymbol{x})}\sum_{i=m+1}^n e^{x_i} \end{aligned} \end{equation} Solving for \lambda(\boldsymbol{x}): \begin{equation} \lambda(\boldsymbol{x})=\log\sum_{i=m+1}^n e^{x_i} - \log\left(\sqrt{(k-m)^2 + \left(\sum_{i=1}^m e^{-x_i}\right)\left(\sum_{i=m+1}^n e^{x_i}\right)}+(k-m)\right) \end{equation} We then iterate through m=0,1,\cdots,n-1 to find the valid \lambda(\boldsymbol{x}). For k=1, this ThreTopK reduces to Softmax only when the largest Softmax probability is at most 1/2: then m=0 and \lambda(\boldsymbol{x})=\log\sum_i e^{x_i}-\log 2, so f(x_i-\lambda(\boldsymbol{x}))=e^{x_i}/\sum_j e^{x_j}. If one component dominates (Softmax probability above 1/2), the valid case is m=1 and the result differs from Softmax.
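The step from the equation above to the closed form for \lambda(\boldsymbol{x}) is a quadratic in e^{\lambda(\boldsymbol{x})}. Writing A_m, B_m and t for the two sums and for e^{\lambda(\boldsymbol{x})} (my notation), multiplying both sides by 2t gives the quadratic below. For m\geq 1 its two roots have product -B_m/A_m<0, so exactly one is positive; rationalizing that root gives the form used in the text, which also covers m=0, where A_0=0 and the equation is linear with t=B_0/(2k).

```latex
% Notation: A_m = \sum_{i=1}^m e^{-x_i}, \quad B_m = \sum_{i=m+1}^n e^{x_i}, \quad t = e^{\lambda(\boldsymbol{x})}
\begin{aligned}
k - m = -\tfrac{1}{2}A_m t + \tfrac{1}{2}B_m t^{-1}
\;&\Longleftrightarrow\; A_m t^2 + 2(k-m)\,t - B_m = 0 \\
\;&\Longrightarrow\; t = \frac{\sqrt{(k-m)^2 + A_m B_m} - (k-m)}{A_m}
= \frac{B_m}{\sqrt{(k-m)^2 + A_m B_m} + (k-m)}
\end{aligned}
```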

Reference implementation:

import numpy as np

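def ThreTopK(x, k):
    x_sort = np.sort(x)  # ascending; entry j corresponds to m = n-1-j
    # lse1[j] = log sum_{i>m} e^{x_i} (over the n-m smallest entries)
    lse1 = np.logaddexp.accumulate(x_sort)
    # lse2[j] = log sum_{i<=m} e^{-x_i} (over the m largest entries; -inf when m = 0)
    lse2 = np.pad(np.logaddexp.accumulate(-x_sort[::-1])[::-1], (0, 1), constant_values=-np.inf)[1:]
    m = np.arange(len(x) - 1, -1, -1)
    x_lamb = lse1 - np.log(np.sqrt((k - m)**2 + np.exp(lse1 + lse2)) + (k - m))
    # Upper bound x_m for each candidate (+inf when m = 0)
    x_sort_shift = np.pad(x_sort[1:], (0, 1), constant_values=np.inf)
    # Keep the candidate with x_m >= lambda >= x_{m+1}; take the first if a boundary tie yields two
    lamb = x_lamb[(x_lamb <= x_sort_shift) & (x_lamb >= x_sort)][0]
    return (1 - np.exp(-np.abs(x - lamb))) * np.sign(x - lamb) * 0.5 + 0.5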

k, x = 10, np.random.randn(100)
ThreTopK(x, k)
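As a cross-check of the closed form, the sketch below (helper names mine) tries each m in an explicit loop and compares the resulting \lambda(\boldsymbol{x}) with a bisection solution of \sum_i f(x_i-\lambda)=k. It assumes k \leq n/2, which guarantees that at least one component lies below \lambda(\boldsymbol{x}) (otherwise every term would exceed 1/2), and moderate values of \boldsymbol{x} so that e^{\pm x_i} does not overflow. For m>k it rewrites \sqrt{(k-m)^2+A_mB_m}+(k-m) as A_mB_m/(\sqrt{(k-m)^2+A_mB_m}-(k-m)) to avoid cancellation.

```python
import numpy as np

def laplace_cdf(z):
    e = np.exp(-np.abs(z)) / 2
    return np.where(z >= 0, 1.0 - e, e)       # 1 - e^{-z}/2 for z >= 0, e^{z}/2 for z < 0

def lambda_closed_form(x, k):
    """Try each m and keep the lambda that satisfies x_m >= lambda >= x_{m+1}."""
    xs = np.sort(x)[::-1]                     # descending
    for m in range(len(xs)):
        A = np.exp(-xs[:m]).sum()             # sum_{i <= m} e^{-x_i} (0 when m = 0)
        B = np.exp(xs[m:]).sum()              # sum_{i > m} e^{x_i}
        d = k - m
        r = np.sqrt(d * d + A * B)
        # r + d, rewritten as A*B / (r - d) when d < 0 to avoid cancellation
        denom = r + d if d >= 0 else A * B / (r - d)
        lamb = np.log(B) - np.log(denom)
        upper = np.inf if m == 0 else xs[m - 1]
        if upper >= lamb >= xs[m]:
            return lamb
    raise ValueError("no consistent m found")

def lambda_bisection(x, k, iters=200):
    low, high = x.min() - 50.0, x.max() + 50.0  # F(low) is about n >= k, F(high) is about 0
    for _ in range(iters):
        mid = 0.5 * (low + high)
        low, high = (low, mid) if laplace_cdf(x - mid).sum() < k else (mid, high)
    return 0.5 * (low + high)

x = np.random.default_rng(2).normal(size=30)
print(lambda_closed_form(x, 5), lambda_bisection(x, 5))  # should agree closely
```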

Summary

This article explored the problem of smooth approximations for the Top-k operator, which is a generalization of smooth approximations for Top-1 like Softmax. We proposed three construction strategies: iterative construction, gradient guidance, and undetermined constants, and analyzed their respective advantages and disadvantages.

Original URL: https://kexue.fm/archives/10373