Top-k Sampling Beyond Gumbel Top-k
Sampling without replacement is simple, until the probabilities are unequal. In machine learning, Gumbel top-k is a common approach. While fast and easy to implement, it has a subtle but important flaw: its marginal probabilities generally differ from its parameters and are intractable to compute. In this post, we revisit Gumbel top-k and turn to sampling theory to explore principled alternatives.
Weighted random sampling
Before we get to Gumbel top-k, let us consider what is perhaps the most intuitive approach to sampling without replacement. That is, to sample sequentially, or draw-by-draw. Given a vector of probabilities,
\[\begin{align} \boldsymbol p \in [0,1]^n, && \sum_{i=1}^n p_i = k, \nonumber \end{align}\]we draw categorical samples sequentially, removing the already-sampled categories and renormalizing after each draw.
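```python
import torch

def weighted_random_sampling(p):
    k = int(p.sum().round())  # sample size, assuming sum(p) == k
    x = torch.zeros_like(p)
    for _ in range(k):
        probs = (1 - x) * p  # exclude already-sampled categories
        probs = probs / probs.sum()  # renormalize over the remaining ones
        idx = torch.multinomial(probs, 1)
        x.scatter_(0, idx, 1)
    return x
```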
All implementations in this blog post are kept as simple as possible to emphasize readability. We also assume that $\sum_{i=1}^n p_i = k$ up to numerical precision, hence the rounding. The algorithm’s time complexity is $\mathcal{O}(nk)$, which may be too slow in some settings. This is precisely the problem that Gumbel top-k solves.
Gumbel top-k
Gumbel top-k [1] is an order sampling procedure, which means it samples a random ranking variable for each element and picks the k elements with the largest values.
```python
import torch

def gumbel_topk(p):
    k = int(p.sum().round())
    u = torch.rand_like(p).clip(1e-9, 1 - 1e-9)  # avoid log(0)
    g = -torch.log(-torch.log(u))  # Gumbel(0, 1) noise
    idx = torch.topk(p.log() + g, k).indices  # k largest perturbed log probabilities
    x = torch.zeros_like(p).scatter_(0, idx, 1)
    return x
```
This type of procedure is efficient, since we avoid the sequential processing of a draw-by-draw procedure like weighted random sampling. Interestingly, both algorithms produce the same distribution over k-hot vectors, as noted in a footnote of the original paper on page 2 [1]. This is great, since it gives us a much faster equivalent to weighted_random_sampling. Unfortunately, this distribution has intractable marginals.
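The equivalence can be checked empirically. The following is a minimal sketch using only the Python standard library rather than PyTorch; `wrs_sample`, `gumbel_topk_sample`, and `empirical_marginals` are hypothetical helpers written for illustration, and the seeds and sample counts are arbitrary. Both samplers should produce marginal frequencies that agree up to Monte Carlo error.

```python
import math
import random

def wrs_sample(p, k, rng):
    """Weighted random sampling: k sequential draws, renormalizing after each."""
    remaining = list(range(len(p)))
    chosen = []
    for _ in range(k):
        i = rng.choices(remaining, weights=[p[j] for j in remaining])[0]
        remaining.remove(i)
        chosen.append(i)
    return chosen

def gumbel_topk_sample(p, k, rng):
    """Gumbel top-k: perturb log p_i with Gumbel(0, 1) noise and keep the k largest."""
    keys = []
    for i, p_i in enumerate(p):
        u = rng.random()
        while u == 0.0:  # random() lies in [0, 1); avoid log(0)
            u = rng.random()
        keys.append((math.log(p_i) - math.log(-math.log(u)), i))
    keys.sort(reverse=True)
    return [i for _, i in keys[:k]]

def empirical_marginals(sampler, p, k, num_samples, seed):
    """Fraction of samples in which each element is included."""
    rng = random.Random(seed)
    counts = [0] * len(p)
    for _ in range(num_samples):
        for i in sampler(p, k, rng):
            counts[i] += 1
    return [c / num_samples for c in counts]

p = [0.2, 0.7, 0.3, 0.8]
pi_wrs = empirical_marginals(wrs_sample, p, 2, 100_000, seed=0)
pi_gumbel = empirical_marginals(gumbel_topk_sample, p, 2, 100_000, seed=1)
print([round(v, 2) for v in pi_wrs])
print([round(v, 2) for v in pi_gumbel])
```

For this choice of $\boldsymbol p$, both lines should print values close to [0.24, 0.68, 0.35, 0.73].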
Other parametrizations
The parametrization of Gumbel top-k can sometimes be confusing. It is common to see it implemented with signatures such as gumbel_topk(w, k), taking real-valued weights and k as an explicit parameter. In this case, the weights (or logits) represent unnormalized log probabilities, since adding the same constant to every weight does not change which k elements are largest. Note that the implementation above takes probabilities and adds Gumbel noise to p.log() instead.
In the original paper, the authors note: “This does not mean that the inclusion probability of element i is proportional to $p_i$: if we sample $k = n$ elements all elements are included with probability 1” [1].
For this reason, we additionally impose the constraint $\sum_{i=1}^n p_i = k$. With this parametrization, if $k = n$, then we must have $p_i = 1, \forall i$. This constraint is the canonical parametrization in sampling theory, where the parameters correspond directly to desired marginal probabilities. Using it makes it easier to reason about marginals and compare different algorithms.
Marginal probabilities
Understanding a distribution over all $n \choose k$ possible k-hot vectors $\boldsymbol x$ is challenging. We often care most about the marginal probabilities for each $x_i$,
\[\begin{align} \pi_i = \mathrm{Pr}(x_i = 1) = \sum_{\boldsymbol x} x_i\,p(\boldsymbol x). \nonumber \end{align}\]Equivalently, this is the expectation $\boldsymbol \pi = \mathbb{E}[\boldsymbol x]$. In sampling theory (which we will get to in the next section), these are known as inclusion probabilities, since $\mathrm{Pr}(x_i = 1)$ is the probability that a particular $x_i$ is included in the sample, i.e., is equal to one.
Higher order inclusion probabilities
We can define higher-order inclusion probabilities. For instance,
\[\begin{align} \pi_{ij} = \mathrm{Pr}(x_i = 1 \land x_j = 1) \nonumber \end{align}\]is the second-order inclusion probability. These form a matrix $\boldsymbol \Pi \in [0,1]^{n \times n}$ with the first-order inclusion probabilities on its diagonal, since $\pi_{ii} = \pi_i$. Adding a third index leads to third-order inclusion probabilities, which form a third-order tensor, and so on.
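As a small worked example, both $\boldsymbol \pi$ and $\boldsymbol \Pi$ can be computed directly from an explicit distribution over k-hot vectors. The sketch below assumes a made-up distribution over the six 2-hot vectors of length 4 (weights 1 to 6, normalized) and represents each sample by the tuple of its selected indices; `inclusion_probabilities` is a hypothetical helper. It also illustrates two identities that hold for any distribution over k-hot vectors: the diagonal of $\boldsymbol \Pi$ equals $\boldsymbol \pi$, and row $i$ of $\boldsymbol \Pi$ sums to $k \pi_i$, because $\sum_j x_i x_j = k x_i$.

```python
from itertools import combinations

def inclusion_probabilities(dist, n):
    """First-order (pi) and second-order (Pi) inclusion probabilities.

    `dist` maps each k-hot vector, given as a tuple of selected indices,
    to its probability.
    """
    pi = [0.0] * n
    Pi = [[0.0] * n for _ in range(n)]
    for sample, prob in dist.items():
        for i in sample:
            pi[i] += prob
            for j in sample:
                Pi[i][j] += prob  # includes j == i, so the diagonal equals pi
    return pi, Pi

# Hypothetical distribution over the 6 possible 2-hot vectors of length 4.
n, k = 4, 2
samples = list(combinations(range(n), k))
weights = [1, 2, 3, 4, 5, 6]
dist = {s: w / sum(weights) for s, w in zip(samples, weights)}

pi, Pi = inclusion_probabilities(dist, n)
print([round(v, 3) for v in pi])  # [0.286, 0.476, 0.571, 0.667]
print(all(abs(sum(Pi[i]) - k * pi[i]) < 1e-12 for i in range(n)))  # True
```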
If we are not careful, we might assume that weighted random sampling and Gumbel top-k produce distributions with marginals $p_i$. This would be a nice property. It is exactly how a categorical distribution works for one-hot vectors (the special case of $k = 1$). However, the actual marginals are inexact. In other words, the desired marginals $\boldsymbol p$ and actual marginals $\boldsymbol \pi_\text{WRS}$ are different,
\[\begin{align} \boldsymbol p = [0.2, 0.7, 0.3, 0.8], && \boldsymbol \pi_\text{WRS} = [0.24, 0.68, 0.35, 0.73]. \nonumber \end{align}\]Here, $\boldsymbol p$ denotes the parameters of the sampling algorithm, while $\boldsymbol \pi_\text{WRS}$ denotes the actual sampling probabilities. Worse yet, the actual marginals are intractable. Computing $\pi_\text{WRS}$ involves marginalizing over all possible k-hot vectors. As already mentioned,
\[\begin{align} \boldsymbol \pi_\text{WRS} = \boldsymbol \pi_\text{Gumbel top-$k$}. \nonumber \end{align}\]This is a substantial problem. Unknown marginals limit both interpretability and probabilistic modeling. Yves Tillé remarks: “What is the use of an algorithm when we cannot compute the inclusion probabilities?” [2].
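For tiny problems, the weighted random sampling marginals can still be computed by brute force, which also shows why this does not scale: the sketch below sums over every ordered sequence of k draws, of which there are $n!/(n-k)!$. It is a plain-Python illustration (the function name `wrs_marginals` is hypothetical) that reproduces the values of $\boldsymbol \pi_\text{WRS}$ above.

```python
from itertools import permutations

def wrs_marginals(p, k):
    """Exact inclusion probabilities of weighted random sampling.

    Enumerates all n! / (n - k)! ordered sequences of draws, so it is only
    feasible for very small n and k.
    """
    n = len(p)
    pi = [0.0] * n
    for seq in permutations(range(n), k):
        prob, remaining = 1.0, sum(p)
        for i in seq:
            prob *= p[i] / remaining  # draw i among the not-yet-sampled items
            remaining -= p[i]
        for i in seq:
            pi[i] += prob
    return pi

p = [0.2, 0.7, 0.3, 0.8]
print([round(v, 2) for v in wrs_marginals(p, 2)])  # [0.24, 0.68, 0.35, 0.73]
```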
Introducing sampling theory
Sampling theory is a field of mathematics and statistics concerned with selecting subsets of populations. A major application of these techniques is survey sampling. So what does this have to do with top-k sampling? In short, finite population sampling means drawing a sample of size k from a population of size n. With unequal probabilities, sampling without replacement is known as probability-proportional-to-size without replacement (PPSWOR) sampling, or πps sampling. In ML terms, πps sampling corresponds to sampling a k-hot vector where each element has a prescribed marginal probability. This is exactly the top-k sampling we are trying to do.
Let us see how methods from sampling theory can help address our problem of inexact marginals, and implement three of them in PyTorch.
Brewer’s method
Brewer’s method [3] can be thought of as a corrected version of weighted random sampling, in the sense that it produces the desired marginals exactly. To implement it, we only need to modify a few lines of code:
```python
import torch

def brewers_method(p):
    k = int(p.sum().round())
    x = torch.zeros_like(p)
    for i in range(1, k + 1):
        c = k - (p * x).sum()  # k minus the parameters of the units selected so far
        probs = (1 - x) * p * (c - p) / (c - p * (k - i + 1))
        probs = probs / probs.sum()
        idx = torch.multinomial(probs, 1)
        x.scatter_(0, idx, 1)
    return x
```
Then, our desired and actual marginals match,
\[\begin{align} \boldsymbol p = [0.2, 0.7, 0.3, 0.8], && \boldsymbol \pi_\text{Brewer's} = [0.2, 0.7, 0.3, 0.8]. \nonumber \end{align}\]Like weighted random sampling, it is a draw-by-draw procedure. Unfortunately, we lost the efficiency of Gumbel top-k, especially for large k. Other than that, it is a nice option with exact marginals.
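The exactness claim can be checked on the running example by enumerating every ordered sequence of draws and multiplying the per-draw probabilities used in `brewers_method`. This is a brute-force sketch in plain Python (the name `brewer_marginals` is hypothetical), feasible only for small n and k, and it assumes that no denominator $c - p_j (k - i + 1)$ is zero.

```python
from itertools import permutations

def brewer_marginals(p, k):
    """Exact inclusion probabilities of Brewer's method for small n and k.

    At draw i (1-based), an unselected unit j is chosen with probability
    proportional to p_j (c - p_j) / (c - p_j (k - i + 1)), where c is k minus
    the sum of p over the units selected so far, as in brewers_method.
    """
    n = len(p)
    pi = [0.0] * n
    for seq in permutations(range(n), k):
        prob, selected = 1.0, set()
        for i, unit in enumerate(seq, start=1):
            c = k - sum(p[j] for j in selected)
            weights = {
                j: p[j] * (c - p[j]) / (c - p[j] * (k - i + 1))
                for j in range(n)
                if j not in selected
            }
            prob *= weights[unit] / sum(weights.values())
            selected.add(unit)
        for unit in seq:
            pi[unit] += prob
    return pi

p = [0.2, 0.7, 0.3, 0.8]
print([round(v, 6) for v in brewer_marginals(p, 2)])  # [0.2, 0.7, 0.3, 0.8]
```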
Conditional Poisson sampling
⚠️ Not to be confused with the Poisson distribution.
Conditional Poisson sampling [4] has a slightly confusing name. The name comes from another sampling algorithm, Poisson sampling. In machine learning terminology, Poisson sampling is simply independent Bernoulli sampling, i.e., x = torch.bernoulli(p) in PyTorch. Poisson sampling does not guarantee that each sample is k-hot, only that the expected sum is k. Conditional Poisson sampling conditions on $\sum_{i=1}^n x_i = k$ so that every sample is k-hot.
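```python
import torch

def conditional_poisson_sampling(p, max_iter=1000):
    k = p.sum().round()
    for _ in range(max_iter):
        x = torch.bernoulli(p)  # Poisson sampling: independent Bernoulli draws
        if torch.isclose(x.sum(), k):  # accept only k-hot samples
            return x
    raise RuntimeError("no k-hot sample was accepted within max_iter attempts")
```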
This is a rejection sampling algorithm. The efficiency of rejection sampling depends on the acceptance rate $\mathrm{Pr}(\sum_{i=1}^n x_i = k)$. This sum of independent, but not identically distributed, Bernoulli variables follows a Poisson binomial distribution. For practical use, there are more efficient algorithms to sample from the same distribution [6].
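The acceptance rate itself is easy to compute: the Poisson binomial probability mass function can be built by adding one Bernoulli variable at a time, in $\mathcal{O}(n^2)$ time. The sketch below (plain Python; `poisson_binomial_pmf` is a hypothetical helper) evaluates it for the running example $\boldsymbol p = [0.2, 0.7, 0.3, 0.8]$ with $k = 2$. Since the number of attempts until acceptance is geometric, the expected number of Bernoulli draws per accepted sample is the reciprocal of the acceptance rate.

```python
def poisson_binomial_pmf(p):
    """Distribution of the sum of independent Bernoulli(p_i) variables.

    pmf[s] = Pr(x_1 + ... + x_n = s), built by adding one variable at a time.
    """
    pmf = [1.0]
    for p_i in p:
        new = [0.0] * (len(pmf) + 1)
        for s, prob in enumerate(pmf):
            new[s] += prob * (1 - p_i)  # x_i = 0
            new[s + 1] += prob * p_i  # x_i = 1
        pmf = new
    return pmf

p = [0.2, 0.7, 0.3, 0.8]
k = round(sum(p))
acceptance = poisson_binomial_pmf(p)[k]
print(round(acceptance, 4))  # 0.4616
print(round(1 / acceptance, 2))  # 2.17 expected attempts per accepted sample
```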
For our running example, the marginal probabilities are the following,
\[\begin{align} \boldsymbol p = [0.2, 0.7, 0.3, 0.8], && \boldsymbol \pi_\text{CP} = [0.12, 0.79, 0.21, 0.88]. \nonumber \end{align}\]Unlike for weighted random sampling and Gumbel top-k, the $\pi_\text{CP}$ can be calculated recursively [5], which avoids enumerating all k-hot vectors. Still, the marginals are inexact. This can be fixed numerically by adjusting the parameters.
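One way to compute $\boldsymbol \pi_\text{CP}$ without enumerating all k-hot vectors uses the fact that, under conditional Poisson sampling, the probability of a k-hot vector is proportional to the product of the odds $w_i = p_i / (1 - p_i)$ of its selected elements. The inclusion probabilities are then ratios of elementary symmetric polynomials, $\pi_i = w_i\, e_{k-1}(\boldsymbol w_{-i}) / e_k(\boldsymbol w)$, where $\boldsymbol w_{-i}$ omits element $i$. The sketch below computes these with a simple dynamic program in $\mathcal{O}(n^2 k)$ time; it assumes $0 < p_i < 1$, the function names are hypothetical, and it is not necessarily the recursion used in [5].

```python
def elementary_symmetric(w, k):
    """Return [e_0, ..., e_k], where e_j sums prod(w_i for i in S) over all j-subsets S."""
    e = [1.0] + [0.0] * k
    for w_i in w:
        for j in range(k, 0, -1):  # go downwards so each w_i enters a product at most once
            e[j] += w_i * e[j - 1]
    return e

def cp_inclusion_probabilities(p):
    """Inclusion probabilities of conditional Poisson sampling with parameters p.

    Assumes 0 < p_i < 1. Pr(x) is proportional to the product of the odds
    w_i = p_i / (1 - p_i) over the selected elements, restricted to k-hot x.
    """
    k = round(sum(p))
    w = [p_i / (1 - p_i) for p_i in p]
    e_k = elementary_symmetric(w, k)[k]
    return [
        w[i] * elementary_symmetric(w[:i] + w[i + 1:], k - 1)[k - 1] / e_k
        for i in range(len(w))
    ]

p = [0.2, 0.7, 0.3, 0.8]
print([round(v, 2) for v in cp_inclusion_probabilities(p)])  # [0.12, 0.79, 0.21, 0.88]
```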
Adjusted CP
Adjusted Conditional Poisson sampling works by finding a set of parameters $\boldsymbol p$ that produce our desired marginals. For example, given the desired marginals
\[\begin{equation} \boldsymbol \pi^* = [0.2, 0.7, 0.3, 0.8], \nonumber \end{equation}\]a straightforward approach to find parameters that produce our desired marginals (described in [7], page 83) is to iterate
\[\begin{equation} \boldsymbol p^{(i+1)} = \boldsymbol p^{(i)} + (\boldsymbol \pi^* - \boldsymbol \pi^{(i)}), \nonumber \end{equation}\]where $\boldsymbol \pi^{(i)}$ denotes the inclusion probabilities of conditional Poisson sampling with parameters $\boldsymbol p^{(i)}$. If we set $\boldsymbol p^{(0)} = \boldsymbol \pi^{*}$ as our initial guess, we get the following results:
\[\begin{align} \boldsymbol p^{(0)} = [0.20, 0.70, 0.30, 0.80], && \boldsymbol \pi^{(0)} = [0.12, 0.79, 0.21, 0.88], \nonumber \\ \boldsymbol p^{(1)} = [0.28, 0.61, 0.39, 0.72], && \boldsymbol \pi^{(1)} = [0.22, 0.65, 0.35, 0.78], \nonumber \\ \boldsymbol p^{(2)} = [0.26, 0.65, 0.35, 0.74], && \boldsymbol \pi^{(2)} = [0.19, 0.72, 0.28, 0.81], \nonumber \\ \boldsymbol p^{(3)} = [0.27, 0.63, 0.37, 0.73], && \boldsymbol \pi^{(3)} = [0.20, 0.69, 0.31, 0.80], \nonumber \\ \boldsymbol p^{(4)} = [0.27, 0.64, 0.36, 0.73], && \underbrace{\boldsymbol \pi^{(4)} = [0.20, 0.70, 0.30, 0.80].}_\text{Equals our desired $\pi^*$} \nonumber \\ \end{align}\]So, conditional Poisson sampling with parameters $p^{(4)}$ produces the desired inclusion probabilities $\boldsymbol \pi^{*}$ (up to numerical errors).
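The iteration is short to write down. The sketch below is plain Python; it computes the conditional Poisson inclusion probabilities by brute-force enumeration (fine for this 4-element example, not for large n), and the names `cp_inclusion_probabilities` and `adjust_cp` are hypothetical. Four iterations reproduce $\boldsymbol p^{(4)}$ from the table above.

```python
from itertools import combinations
from math import prod

def cp_inclusion_probabilities(p):
    """CP inclusion probabilities by brute-force enumeration (small n only).

    Pr(x) is proportional to the product of the odds p_i / (1 - p_i) over the
    selected elements, restricted to k-hot x. Assumes 0 < p_i < 1.
    """
    n, k = len(p), round(sum(p))
    odds = [p_i / (1 - p_i) for p_i in p]
    pi, total = [0.0] * n, 0.0
    for s in combinations(range(n), k):
        weight = prod(odds[i] for i in s)
        total += weight
        for i in s:
            pi[i] += weight
    return [v / total for v in pi]

def adjust_cp(pi_target, num_iter):
    """Iterate p <- p + (pi_target - pi(p)), starting from p = pi_target."""
    p = list(pi_target)
    for _ in range(num_iter):
        pi = cp_inclusion_probabilities(p)
        p = [p_i + (t_i - q_i) for p_i, t_i, q_i in zip(p, pi_target, pi)]
    return p

pi_target = [0.2, 0.7, 0.3, 0.8]
p4 = adjust_cp(pi_target, 4)
print([round(v, 2) for v in p4])  # [0.27, 0.64, 0.36, 0.73]
print([round(v, 2) for v in cp_inclusion_probabilities(p4)])  # [0.2, 0.7, 0.3, 0.8]
```

Note that each update preserves $\sum_i p_i = k$, since both $\boldsymbol \pi^*$ and $\boldsymbol \pi^{(i)}$ sum to $k$.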
Maximum entropy
Among all distributions over k-hot vectors with the same inclusion probabilities, conditional Poisson sampling has the highest entropy. Such distributions may be preferred based on the principle of maximum entropy.
Intuitively, this is the maximum entropy distribution because it is simply an independent Bernoulli distribution (which has the maximum entropy among distributions over binary vectors with given marginals) conditioned on the constraint that samples are k-hot. Practically speaking, many other top-k sampling methods also produce high entropies [8].
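For the running example, this property can be checked numerically. The sketch below enumerates the full design, i.e., the probability of every 2-subset, for conditional Poisson sampling with adjusted parameters and for Brewer’s method. Both have inclusion probabilities $\boldsymbol \pi^* = [0.2, 0.7, 0.3, 0.8]$, so the maximum entropy property predicts that the conditional Poisson design has entropy at least as high. All helper names are hypothetical, and the brute-force enumeration only works for small n.

```python
from itertools import combinations, permutations
from math import log, prod

def cp_design(p):
    """Pr(S) of every k-subset S under conditional Poisson sampling with parameters p."""
    n, k = len(p), round(sum(p))
    odds = [p_i / (1 - p_i) for p_i in p]
    weights = {s: prod(odds[i] for i in s) for s in combinations(range(n), k)}
    total = sum(weights.values())
    return {s: v / total for s, v in weights.items()}

def marginals(design, n):
    """First-order inclusion probabilities of a design."""
    pi = [0.0] * n
    for s, prob in design.items():
        for i in s:
            pi[i] += prob
    return pi

def adjusted_cp_design(pi_target, num_iter=100):
    """CP design whose inclusion probabilities match pi_target (fixed-point iteration)."""
    p = list(pi_target)
    for _ in range(num_iter):
        pi = marginals(cp_design(p), len(p))
        p = [p_i + (t_i - q_i) for p_i, t_i, q_i in zip(p, pi_target, pi)]
    return cp_design(p)

def brewer_design(p):
    """Pr(S) of every k-subset S under Brewer's method, summing over draw orders."""
    n, k = len(p), round(sum(p))
    design = {}
    for seq in permutations(range(n), k):
        prob, selected = 1.0, set()
        for i, unit in enumerate(seq, start=1):
            c = k - sum(p[j] for j in selected)
            weights = {
                j: p[j] * (c - p[j]) / (c - p[j] * (k - i + 1))
                for j in range(n)
                if j not in selected
            }
            prob *= weights[unit] / sum(weights.values())
            selected.add(unit)
        key = tuple(sorted(seq))
        design[key] = design.get(key, 0.0) + prob
    return design

def entropy(design):
    """Shannon entropy (in nats) of a distribution over subsets."""
    return -sum(prob * log(prob) for prob in design.values() if prob > 0)

pi_target = [0.2, 0.7, 0.3, 0.8]
cp, brewer = adjusted_cp_design(pi_target), brewer_design(pi_target)
print(round(entropy(cp), 4), round(entropy(brewer), 4))
```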
“[…] conditional Poisson sampling is probably the best solution to the problem of sampling with unequal probabilities, although one can object that other procedures provide very similar results”, Yves Tillé [7] (preface). With this in mind, let us look at one such algorithm that can produce similar results.
Pareto sampling
⚠️ Not to be confused with the Pareto distribution.
Pareto sampling [9] is an order sampling algorithm, like Gumbel top-k. It is as fast as Gumbel top-k while implementing a more tractable distribution. The name likely stems from the fact that the noise used in the ranking variables is Lomax-distributed,
\[\begin{align} U \sim \mathrm{Uniform}(0, 1) \implies \frac{U}{1 - U} \sim \mathrm{Lomax}(1, 1), \nonumber \end{align}\]which is a type II Pareto distribution. It can also be written as $1 + \frac{U}{1 - U} \sim \mathrm{Pareto}(1, 1)$.
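A short derivation confirms the claim, using only that $u \mapsto u/(1-u)$ is increasing on $[0, 1)$:
\[\begin{align} \mathrm{Pr}\left(\frac{U}{1 - U} \le t\right) = \mathrm{Pr}\left(U \le \frac{t}{1 + t}\right) = \frac{t}{1 + t} = 1 - (1 + t)^{-1}, && t \ge 0, \nonumber \end{align}\]which is the Lomax CDF $1 - (1 + t/\lambda)^{-\alpha}$ with $\alpha = \lambda = 1$. Likewise, for $s \ge 1$, $\mathrm{Pr}\left(1 + \frac{U}{1 - U} > s\right) = 1 - \frac{s - 1}{s} = \frac{1}{s}$, which is the survival function $(x_m / s)^{\alpha}$ of a Pareto distribution with $x_m = \alpha = 1$.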
```python
import torch

def pareto(p, heuristic=False):
    k = int(p.sum().round())
    u = torch.rand_like(p).clip(1e-9, 1 - 1e-9)
    q = (u / (1 - u)) / (p / (1 - p))  # Lomax(1, 1) noise divided by the odds of p
    if heuristic:
        d = (p * (1 - p)).sum()
        q = q * torch.exp(p * (1 - p) * (p - 0.5) / d**2)  # heuristic adjustment [10]
    idx = torch.topk(-q, k).indices  # k smallest ranking variables
    x = torch.zeros_like(p).scatter_(0, idx, 1)
    return x
```
Using Pareto sampling, we get the following marginals,
\[\begin{align} \boldsymbol p = [0.2, 0.7, 0.3, 0.8], && \hat{\boldsymbol \pi}_\text{Pareto} = [0.18, 0.73, 0.27, 0.82]. \nonumber \end{align}\]Here, $\hat{\boldsymbol \pi}_\text{Pareto}$ was computed empirically using 1’000’000 samples drawn using the pareto function. Like for conditional Poisson sampling, it is possible to compute the inclusion probabilities of Pareto sampling and find adjusted parameters to sample with desired marginals.
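Empirical marginals like these are straightforward to reproduce. The sketch below is a plain-Python re-implementation of `pareto` (standard library only; `pareto_sample` and `empirical_marginals` are hypothetical names), with a `heuristic` flag mirroring the one in the PyTorch version. It uses 100’000 samples rather than 1’000’000 to keep it fast, so each entry has a Monte Carlo standard error of at most about 0.0016.

```python
import math
import random

def pareto_sample(p, k, rng, heuristic=False):
    """Pareto order sampling: select the k elements with the smallest q_i.

    Plain-Python version of `pareto`; assumes 0 < p_i < 1.
    """
    d = sum(p_i * (1 - p_i) for p_i in p)
    keys = []
    for i, p_i in enumerate(p):
        u = rng.random()  # in [0, 1), so 1 - u > 0
        q = (u / (1 - u)) / (p_i / (1 - p_i))  # Lomax(1, 1) noise divided by the odds
        if heuristic:
            q *= math.exp(p_i * (1 - p_i) * (p_i - 0.5) / d**2)
        keys.append((q, i))
    keys.sort()
    return [i for _, i in keys[:k]]

def empirical_marginals(p, num_samples, seed, heuristic=False):
    """Fraction of samples in which each element is included."""
    rng = random.Random(seed)
    k = round(sum(p))
    counts = [0] * len(p)
    for _ in range(num_samples):
        for i in pareto_sample(p, k, rng, heuristic):
            counts[i] += 1
    return [c / num_samples for c in counts]

p = [0.2, 0.7, 0.3, 0.8]
pi_plain = empirical_marginals(p, 100_000, seed=0)
pi_heuristic = empirical_marginals(p, 100_000, seed=0, heuristic=True)
print([round(v, 2) for v in pi_plain])
print([round(v, 2) for v in pi_heuristic])
```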
Heuristic adjustment
An easy way to get marginals closer to $\boldsymbol p$ is to use the following adjustment [10]:
\[\begin{align} \hat{q}_i = \exp\left(\frac{p_i(1 - p_i)(p_i - \frac{1}{2})}{\left[\sum_{j=1}^n p_j(1 - p_j)\right]^2}\right) q_i. \nonumber \end{align}\]Unlike the iterative adjustment described for conditional Poisson sampling, this adjustment is in closed form. However, it is not exact.
\[\begin{align} \boldsymbol p = [0.2, 0.7, 0.3, 0.8], && \hat{\boldsymbol \pi}_\text{Heuristic} = [0.20, 0.71, 0.29, 0.80]. \nonumber \end{align}\]Here, $\hat{\boldsymbol \pi}_\text{Heuristic}$ was computed empirically using 1’000’000 samples drawn using the pareto function with heuristic=True.
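To see which way the heuristic pushes each element, the adjustment factors can be evaluated directly for the running example. Since the elements with the k smallest ranking variables are selected, a factor below one makes an element more likely to be included. The helper name `heuristic_factors` is hypothetical.

```python
import math

def heuristic_factors(p):
    """Multiplicative factors exp(p_i (1 - p_i) (p_i - 1/2) / d^2) applied to q_i."""
    d = sum(p_j * (1 - p_j) for p_j in p)
    return [math.exp(p_j * (1 - p_j) * (p_j - 0.5) / d**2) for p_j in p]

p = [0.2, 0.7, 0.3, 0.8]
print([round(f, 3) for f in heuristic_factors(p)])  # [0.916, 1.08, 0.926, 1.092]
```

Elements with $p_i < 1/2$ get factors below one and elements with $p_i > 1/2$ get factors above one, which matches the direction of the errors in $\hat{\boldsymbol \pi}_\text{Pareto}$: small parameters were under-sampled and large ones over-sampled.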
Further reading
This blog post was inspired by Tillé’s article, Remarks on some misconceptions about unequal probability sampling without replacement [2].
More sampling algorithms – Tillé’s book [7] presents many other sampling algorithms, such as Sampford sampling, systematic sampling, and the pivotal method.
Computing $\boldsymbol p$ from weights – Given some real-valued weights, how do we (differentiably) compute marginal probabilities that sum to k? The answer is to use a top-k relaxation. See my recent workshop paper Differentiable Top-k: From One-Hot to k-Hot [11].
Differentiable sampling – There are multiple papers on gradient estimates for top-k sampling using techniques similar to those for categorical and Bernoulli sampling: REINFORCE, straight-through, Gumbel-softmax, etc. See my recent ICLR [12] and workshop [11] papers and the references therein.
References
[1] Wouter Kool, Herke van Hoof, Max Welling, Stochastic Beams and Where To Find Them: The Gumbel-Top-k Trick for Sampling Sequences Without Replacement, ICML, 2019.
[2] Yves Tillé, Remarks on some misconceptions about unequal probability sampling without replacement, Computer Science Review, 2023.
[3] K.R.W. Brewer, A Simple Procedure for Sampling πpswor, Australian Journal of Statistics, 1975.
[4] Jaroslav Hájek, Asymptotic Theory of Rejective Sampling with Varying Probabilities from a Finite Population, The Annals of Mathematical Statistics, 1964.
[5] Xiang-Hui Chen, Arthur P. Dempster, Jun S. Liu, Weighted Finite Population Sampling to Maximize Entropy, Biometrika, 1994.
[6] Lennart Bondesson, Imbi Traat, Anders Lundqvist, Pareto Sampling versus Sampford and Conditional Poisson Sampling, Scandinavian Journal of Statistics, 2006.
[7] Yves Tillé, Sampling Algorithms, Springer, 2006.
[8] Anton Grafström, Entropy of unequal probability sampling designs, Statistical Methodology, 2010.
[9] Bengt Rosén, On sampling with probability proportional to size, Journal of Statistical Planning and Inference, 1997.
[10] Anders Lundqvist, Contributions to the Theory of Unequal Probability Sampling, Doctoral Thesis, Umeå University, 2009.
[11] Klas Wijk, Ricardo Vinuesa, Hossein Azizpour, Differentiable Top-k: From One-Hot to k-Hot, EurIPS 2025 Workshop on Differentiable Systems and Scientific Machine Learning, 2025.
[12] Klas Wijk, Ricardo Vinuesa, Hossein Azizpour, SFESS: Score Function Estimators for k-Subset Sampling, ICLR, 2025.
@article{wijk2026top-k-sampling-beyond-gumbel-top-k,
title = {Top-k Sampling Beyond Gumbel Top-k},
author = {Wijk, Klas},
year = {2026},
month = {Jan},
url = {https://klaswijk.github.io/blog/2026/topk_sampling_beyond_gumbel_topk/}
}