Fixing Random, part 8
Last time on FAIC we sketched out an O(log n) algorithm for sampling from a weighted distribution of integers by implementing a discrete variation on the "inverse transform" method where the "inverse" step involves a binary search.
Can we do better than O(log n)? Yes, but we'll need an entirely different approach. I'll sketch out two methods, implementing one this time and the other in the next exciting episode.
Again, let's suppose we have some weights, say 10, 11, 5. We have three possible results, and the highest weight is 11. Let's construct a 3 by 11 rectangle that looks like our ideal histogram; I'll put dashes in to more clearly indicate the "spaces":
0|**********-
1|***********
2|*****------
Here's an algorithm for sampling from this distribution:
- Uniformly choose a random row and column in the rectangle; it's easy to choose a number from 0 to 2 for the row and a number from 0 to 10 for the column.
- If that point is a *, then the sample is the generated row number.
- If the point is a -, try again.
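The three steps above translate almost directly into code. Here is a minimal C# sketch that assumes plain `System.Random` in place of this series' distribution types. The `DartSampler` class and its `Sample` method are hypothetical names, not part of the series' library.

```csharp
using System;
using System.Linq;

// Hypothetical helper, not part of this series' library: rejection
// sampling by throwing darts at a rows-by-max rectangle of stars and dashes.
public static class DartSampler
{
    // Assumes every weight is non-negative and at least one is positive;
    // otherwise there are no stars and the loop never ends.
    public static int Sample(int[] weights, Random random)
    {
        int max = weights.Max();
        while (true)
        {
            // Uniformly choose a row (0 .. n - 1) and a column (0 .. max - 1).
            int row = random.Next(weights.Length);
            int column = random.Next(max);
            // Row r has weights[r] stars on the left and dashes after them,
            // so the dart hits a star exactly when column < weights[row].
            if (column < weights[row])
                return row;
            // We hit a dash: reject this point and throw again.
        }
    }
}
```

For example, `DartSampler.Sample(new[] { 10, 11, 5 }, new Random())` should return 0, 1 and 2 roughly 10/26, 11/26 and 5/26 of the time.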
Basically, we're throwing darts at the rectangle, and the likelihood of hitting a valid point on a particular row is proportional to the weight of that row.
In case that's not clear, let's work out the probabilities. To throw our dart, first we'll pick a row, uniformly, and then we'll pick a point in that row, uniformly. We have a 1/3 * 10/11 = 10/33 chance of hitting a star from row 0, a 1/3 * 11/11 = 11/33 chance of hitting a star from row 1, a 1/3 * 5/11 = 5/33 chance of hitting a star from row 2, and that leaves a 7/33 chance of going again. We will eventually pick a *, and the values generated will conform to the desired distribution.
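To double-check that arithmetic, we can count every point in the rectangle instead of reasoning about it. This sketch, using a hypothetical `DartOdds.Tally` helper, visits all rows-times-max points, each of which is equally likely to be hit, and counts the stars in each row plus the dashes.

```csharp
using System;

// Hypothetical helper: tallies every (row, column) point of the
// rows-by-max rectangle. Each point is equally likely, so a tally divided
// by the rectangle's area is an exact per-throw probability.
public static class DartOdds
{
    // Returns one count per row, plus a final element counting the dashes.
    public static int[] Tally(int[] weights)
    {
        int max = 0;
        foreach (int w in weights)
            max = Math.Max(max, w);
        var tally = new int[weights.Length + 1];
        for (int row = 0; row < weights.Length; row++)
        {
            for (int column = 0; column < max; column++)
            {
                if (column < weights[row])
                    tally[row]++;              // a star in this row
                else
                    tally[weights.Length]++;   // a dash
            }
        }
        return tally;
    }
}
```

For weights 10, 11, 5 it returns 10, 11, 5 and 7; dividing by the 33 points in the rectangle gives the 10/33, 11/33, 5/33 and 7/33 above.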
Let's implement it. We'll throw away our previous attempt and start over. (No pun intended.)
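// Assumes the IDistribution<int>, Bernoulli and SDU types from earlier
// in this series, plus System.Collections.Generic and System.Linq.
private readonly List<int> weights;
private readonly List<IDistribution<int>> distributions;
private WeightedInteger(List<int> weights)
{
  this.weights = weights;
  this.distributions =
    new List<IDistribution<int>>(weights.Count);
  int max = weights.Max();
  // One Bernoulli per row: 0 ("hit a star") with weight w,
  // 1 ("hit a dash") with weight max - w.
  foreach(int w in weights)
    distributions.Add(Bernoulli.Distribution(w, max - w));
}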
All right, we have three distributions; in each, a zero is success and a one is failure. In our example of weights 10, 11, 5, the first distribution is "10 to 1 odds of success", the second is "always success", and the third is "5 to 6 odds of success". Now we sample in a loop until we succeed: on each iteration we uniformly choose a distribution and take a single sample from it. If that sample is a success we return the chosen row; if not, we start over with a new uniformly chosen row.
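public int Sample()
{
  // Uniformly choose among rows 0 through weights.Count - 1.
  var rows = SDU.Distribution(0, weights.Count - 1);
  while (true)
  {
    // Pick a fresh row on every iteration...
    int row = rows.Sample();
    // ...and accept it if that row's Bernoulli produces 0 (success).
    if (distributions[row].Sample() == 0)
      return row;
    // Otherwise we hit a dash: reject and start over.
  }
}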
We do two samples per loop iteration; how many iterations are we likely to do? In our example we'll get a result on the first iteration 26/33 of the time, because we have 26 hits out of 33 possibilities. That's 79% of the time. We get a result after one or two iterations 95% of the time. The vast majority of the time we are going to get a result in just a handful of iterations.
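Each iteration succeeds independently with the same probability p, so the number of iterations follows a geometric distribution: the chance of being done within k iterations is 1 - (1 - p)^k, and the expected number of iterations is 1/p. Here is a small sketch of those two formulas; the `RejectionOdds` class is a hypothetical name.

```csharp
using System;

// Hypothetical helper: odds for a loop whose iterations each succeed
// independently with probability p (a geometric distribution).
public static class RejectionOdds
{
    // Probability of finishing within k iterations: 1 - (1 - p)^k.
    public static double WithinIterations(double p, int k) =>
        1.0 - Math.Pow(1.0 - p, k);

    // Expected number of iterations until the first success: 1 / p.
    public static double ExpectedIterations(double p) => 1.0 / p;
}
```

With p = 26/33, `WithinIterations(p, 2)` is 1040/1089, about 0.955, and `ExpectedIterations(p)` is 33/26, about 1.27 iterations.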
But now let's think again about pathological cases. Remember our distribution from last time that had 1000 weights: 1, 1, 1, 1, …, 1, 1001. Consider what that histogram looks like.
0|*------...--- (1 star, 1000 dashes)
1|*------...--- (1 star, 1000 dashes)
[...]
998|*------...--- (1 star, 1000 dashes)
999|*******...*** (1001 stars, 0 dashes)
Our first example has 26 stars and 7 dashes. In our pathological example there will be 2000 stars and 999000 dashes, so the probability of exiting the loop on any particular iteration is about one in 500. On average we are going to loop about 500 times in this scenario! This is far worse than our O(log n) option in pathological cases.
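The acceptance rate depends only on the weights: it is the sum of the weights divided by the area of the rectangle, which is the number of weights times the maximum weight. Here is a sketch that computes the star and dash counts for any weights; the `RectangleStats` name is hypothetical, and it uses `long` so that large rectangles don't overflow.

```csharp
using System;
using System.Linq;

// Hypothetical helper: star and dash counts of the rows-by-max rectangle,
// and the chance that a single dart lands on a star.
public static class RectangleStats
{
    public static (long Stars, long Dashes, double Acceptance) Measure(int[] weights)
    {
        long max = weights.Max();
        long stars = weights.Sum(w => (long)w);   // one star per unit of weight
        long area = weights.Length * max;         // rows times columns
        return (stars, area - stars, (double)stars / area);
    }
}
```

For the pathological weights (999 ones followed by 1001) this gives 2000 stars, 999000 dashes and an acceptance rate of 2000/1001000, so about 500.5 iterations on average.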
This algorithm is called "rejection sampling" (because we "reject" the samples that do not hit a *) and it works very well if all the weights are close to the maximum weight, but it works extremely poorly if there is a small number of high-weight outputs and a large number of low-weight outputs.
Fortunately there is a way to fix that problem. I'm going to reject rejection sampling, and move on to our third and final technique.
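Let's redraw our original histogram with two changes. First, instead of stars I'm going to fill in the number sampled. Second, I'm going to triple the width of every row, making it a 3 by 33 rectangle. (Because there are three rows, tripling every row guarantees that the number of dashes is a multiple of three, which we'll need shortly.)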
0|000000000000000000000000000000---
1|111111111111111111111111111111111
2|222222222222222------------------
Plainly this is logically no different; we could "throw a dart", and the number that we hit is the sample; if we hit a dash, we go again. But we still have the problem that 21 out of 99 times we're going to hit a dash.
My goal is to get rid of all the dashes, but I'm going to start by trying to get 7 dashes in each row. There are 21 dashes available, three rows, so that's seven in each row.
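The same bookkeeping works for any weights, not just 10, 11, 5. Here is a sketch of a hypothetical `ScaledRectangle.Build` helper, assuming we always scale by the number of rows n: the rectangle becomes n rows by n * max columns, the total dash count is n * (width - sum of weights), and so each row gets width - sum of weights dashes and keeps sum of weights numbers.

```csharp
using System;
using System.Linq;

// Hypothetical helper: scales every row by the number of rows n, so the
// rectangle is n rows by n * max columns and the dash count divides by n.
public static class ScaledRectangle
{
    public static (int Width, int[] Scaled, int DashesPerRow) Build(int[] weights)
    {
        int n = weights.Length;
        int width = n * weights.Max();
        int[] scaled = weights.Select(w => n * w).ToArray();
        // totalDashes = n * width - n * sum(weights) = n * (width - sum(weights)),
        // which is always a multiple of n.
        int totalDashes = n * width - scaled.Sum();
        return (width, scaled, totalDashes / n);
    }
}
```

For weights 10, 11, 5 this gives a width of 33, scaled rows of 30, 33 and 15, and 33 - 26 = 7 dashes per row, leaving 26 numbers in each row.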
To achieve that, I'm going to first pick an "excessive" row (too many numbers, too few dashes) and a "deficient" row (too few numbers, too many dashes), and move some of the numbers from the excessive row to the deficient row, so that the deficient row ends up with exactly seven dashes. For example, I'll move eleven of the 0s into the 2 row, and swap eleven of the 2 row's dashes into the 0 row.
0|0000000000000000000--------------
1|111111111111111111111111111111111
2|22222222222222200000000000-------
We've achieved our goal for the 2 row. Now we do the same thing again. I'm going to move seven of the 1s into the 0 row:
0|00000000000000000001111111-------
1|11111111111111111111111111-------
2|22222222222222200000000000-------
We've achieved our goal. And now we can get rid of the dashes without changing the distribution:
0|00000000000000000001111111
1|11111111111111111111111111
2|22222222222222200000000000
Now we can throw darts and never get any rejections!
The result is a rectangle where we can again pick a row uniformly, and then sample from that row to choose which result we want. There are only ever one or two possibilities in a row, so each row can be represented by a Bernoulli or Singleton distribution. Therefore we can sample from this distribution with exactly two samples: a uniform integer to pick the row, and a second sample from the distribution associated with that row.
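Sampling from the final table takes exactly two uniform choices and never rejects. This sketch hard-codes the dash-free table for weights 10, 11, 5 as an "own count" and an "alias" per row. The `AliasTableSample` name is hypothetical, and instead of a Bernoulli object per row it picks a uniform column, which gives the same odds.

```csharp
using System;

// Hypothetical sketch: sampling from the dash-free table for weights 10, 11, 5.
// Row r holds OwnCount[r] copies of r, followed by (Width - OwnCount[r])
// copies of Alias[r].
public static class AliasTableSample
{
    static readonly int[] OwnCount = { 19, 26, 15 };
    static readonly int[] Alias = { 1, 1, 0 };   // row 1 is full, so its alias is unused
    const int Width = 26;

    // The number written at a given (row, column) of the table.
    public static int Lookup(int row, int column) =>
        column < OwnCount[row] ? row : Alias[row];

    // Two uniform choices and no rejection: pick a row, then a column.
    public static int Sample(Random random) =>
        Lookup(random.Next(OwnCount.Length), random.Next(Width));
}
```

Every one of the 78 cells is equally likely, and among them are 30 zeros, 33 ones and 15 twos, which is 10/26, 11/26 and 5/26.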
Aside: You might wonder why I chose to move eleven 0s and then seven 1s. I didn't have to! I also could have moved eleven 1s into the 2 row, and then four 0s into the 1 row. Doesn't matter; either would work. As we'll see in the next episode, as long as you make progress towards a solution, you'll find a solution.
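The moving procedure can be sketched with integer counts. This is a minimal sketch under my own assumptions, not the implementation from the next episode: `AliasTableSketch.Build` is a hypothetical name, rows are scaled by n as before, and the excessive and deficient rows are kept on two stacks, which is an arbitrary choice. Each pass fills one deficient row completely, so every row receives at most one alias, and because the total excess always equals the total deficit, an excessive row is available whenever a deficient one remains.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Hypothetical sketch of "move numbers from an excessive row into a
// deficient row until every row has the same count".
public static class AliasTableSketch
{
    // Row i ends up holding OwnCount[i] copies of i; its remaining
    // sum(weights) - OwnCount[i] cells hold copies of Alias[i].
    public static (int[] OwnCount, int[] Alias) Build(int[] weights)
    {
        int n = weights.Length;
        int target = weights.Sum();                          // numbers per row once dashes are gone
        int[] count = weights.Select(w => n * w).ToArray();  // rows scaled by n
        int[] alias = Enumerable.Range(0, n).ToArray();      // "no alias" by default
        var excessive = new Stack<int>();
        var deficient = new Stack<int>();
        for (int i = 0; i < n; i++)
        {
            if (count[i] > target) excessive.Push(i);
            else if (count[i] < target) deficient.Push(i);
        }
        while (deficient.Count > 0)
        {
            int poor = deficient.Pop();
            int rich = excessive.Pop();
            int moved = target - count[poor];
            alias[poor] = rich;      // fill the poor row's gap with rich's number
            count[rich] -= moved;    // the rich row gives those numbers away
            if (count[rich] > target) excessive.Push(rich);
            else if (count[rich] < target) deficient.Push(rich);
        }
        return (count, alias);
    }
}
```

For weights 10, 11, 5, this particular stack order moves eleven 1s into the 2 row and then four 0s into the 1 row, returning own counts 26, 22, 15 and aliases 0, 0, 1. The totals still come out to 30 zeros, 33 ones and 15 twos.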
This algorithm is called the "alias method" because some portion of each row is "aliased" to another row. It's maybe not entirely clear how to implement this algorithm, but it's quite straightforward if you are careful.
Next time on FAIC: The alias method implementation.