Fixing Random, part 11
Last time on FAIC we drew a line in the sand and required that the predicates and projections used by Where and Select on discrete distributions be "pure" functions: they must complete normally, produce no side effects, consume no external state, and never change their behaviour. They must produce the same result when called with the same argument, every time. If we make these restrictions then we can get some big performance wins out of Where and Select. Let's see how.
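For instance, a predicate like x => x % 2 == 0 qualifies, but something like x => DateTime.Now.Second % 2 == 0 does not: it consumes external state and changes its behaviour from one call to the next.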
The biggest problem we face is that possibly-long-running loop in the Where; basically we are rejection-sampling the distribution, and we know that can take a long time. Is there a way to directly produce a new distribution that can be efficiently sampled?
Of course there is, and it was a little silly of us to miss it. Let's make a helper method:
public static IDiscreteDistribution<T> ToWeighted<T>(
  this IEnumerable<T> items,
  IEnumerable<int> weights)
{
  var list = items.ToList();
  return WeightedInteger.Distribution(weights)
    .Select(i => list[i]);
}
There's an additional helper method I'm going to need in a couple of episodes, so let's just make it now:
public static IDiscreteDistribution<T> ToWeighted<T>(
  this IEnumerable<T> items,
  params int[] weights) =>
  items.ToWeighted((IEnumerable<int>)weights);
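Just to see it in action, here's a usage sketch with some arbitrary weights:

var fruits = new[] { "apple", "orange", "cherry" }
  .ToWeighted(3, 1, 2);
// fruits samples "apple" three times as often as "orange",
// and "cherry" twice as often as "orange".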
And now we can delete our Conditioned class altogether, and replace it with:
public static IDiscreteDistribution<T> Where<T>(
  this IDiscreteDistribution<T> d,
  Func<T, bool> predicate)
{
  var s = d.Support().Where(predicate).ToList();
  return s.ToWeighted(s.Select(t => d.Weight(t)));
}
Recall that the WeightedInteger factory will throw if the support is empty, and will return a Singleton or a Bernoulli if the support is of size one or two; we don't have to worry about those cases here. No more maybe-long-running loop! We now do at most two underlying samples per call to Sample.
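For instance, here's a sketch using the StandardDiscreteUniform distribution from earlier in this series:

var d10 = StandardDiscreteUniform.Distribution(1, 10);
var evens = d10.Where(x => x % 2 == 0);
// evens is uniform over { 2, 4, 6, 8, 10 }; sampling it never
// loops, no matter how improbable the condition was in the
// underlying distribution.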
Exercise: We're doing probabilistic workflows here; it seems strange that we are either 100% rejecting or 100% accepting in Where. Can you write an optimized implementation of this method?
public static IDiscreteDistribution<T> Where<T>(
  this IDiscreteDistribution<T> d,
  Func<T, IDiscreteDistribution<bool>> predicate)
That is, we accept or reject each T with a probability given by a probabilistic predicate.
This one will be easier to implement once we have the gear we're going to develop a few episodes from now, but I mention it now to get you thinking about the problem.
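That said, if you'd like a head start, here is one rough sketch that works with only the tools we have so far: weight each element by its underlying weight times its probability of acceptance, and put everything over a common denominator so that the weights stay integers. This is just one naive approach, not the implementation we'll eventually build:

public static IDiscreteDistribution<T> Where<T>(
  this IDiscreteDistribution<T> d,
  Func<T, IDiscreteDistribution<bool>> predicate)
{
  var s = d.Support().ToList();
  var accepts = s.Select(predicate).ToList();
  var trues = accepts.Select(b => b.Weight(true)).ToList();
  var totals = accepts.Select(
    b => b.Weight(true) + b.Weight(false)).ToList();
  // A common denominator keeps the combined weights integral;
  // no attempt is made to reduce them or to avoid overflow.
  int denom = totals.Aggregate(1, (p, t) => p * t);
  return s.ToWeighted(s.Select(
    (t, i) => d.Weight(t) * trues[i] * (denom / totals[i])));
}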
That takes care of optimizing Where. What about optimizing Select?
Of course we can do the same trick. We just have to take into account the fact that the projection might cause some members of the underlying support to "merge" their weights:
public static IDiscreteDistribution<R> Select<A, R>(
  this IDiscreteDistribution<A> d,
  Func<A, R> projection)
{
  var dict = d.Support()
    .GroupBy(projection, a => d.Weight(a))
    .ToDictionary(g => g.Key, g => g.Sum());
  var rs = dict.Keys.ToList();
  return Projected<int, R>.Distribution(
    WeightedInteger.Distribution(
      rs.Select(r => dict[r])),
    i => rs[i]);
}
That is: we compute the new support, and the weights for each element of it. Now we can construct a weighted integer distribution that chooses an offset into the support.
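For example, projecting a fair die onto its value modulo 3 merges the support pairwise; here's a sketch (ShowWeights is the helper we wrote earlier in the series):

var d6 = StandardDiscreteUniform.Distribution(1, 6);
var mod3 = d6.Select(x => x % 3);
// 1 and 4, 2 and 5, 3 and 6 merge, so the support is
// { 0, 1, 2 }, each with weight 2.
Console.WriteLine(mod3.ShowWeights());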
Exercise: Why did I not write the far more concise:
return rs.ToWeighted(rs.Select(r => dict[r]));
?
Let's think for a minute about whether this really does what we want. Suppose for example we do a projection on a Singleton:
- We'll end up with a single-item dictionary with some non-zero weight.
- The call to the weighted integer factory will return a Singleton<int> that always returns zero.
- We'll build a projection on top of that, and the projection factory will detect that the support is of size one, and return a Singleton<R> with the projected value.
Though we've gone through a bunch of unnecessary object constructions, we end up with what we want.
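In code, the whole dance looks something like this sketch:

var answer = Singleton<int>.Distribution(6).Select(x => x * 7);
// answer is a Singleton<int> that always samples 42.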
Furthermore, suppose we have a projection, and we do a Select on that: we avoid the problem we have in LINQ-to-objects, where we build a projection on top of a projection and end up with multiple objects representing the workflow.
Aside: That last sentence is an oversimplification; in LINQ-to-objects there is an optimization for this scenario; we detect Select-on-Select (and Select-on-Where and so on) and build up a single object that represents the combined projection, but that single object still calls all the delegates. So in that sense there are still many objects in the workflow; there's just a single object coordinating all the delegates.
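In essence the combined object does something like this sketch; this is a simplification, not the actual LINQ-to-objects source:

static IEnumerable<R> CombinedSelect<A, B, R>(
  IEnumerable<A> source, Func<A, B> f, Func<B, R> g) =>
  source.Select(a => g(f(a)));
// One iterator object, but both delegates still run on every element.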
In our scenario, by contrast, we spend time up front calling the projection once on every member of the support, so we never need to call it again when sampling.
Again, I'm not trying to make the series of factories here the most efficient it could be; we're creating a lot of garbage. Were I building this sort of system for industrial use, I'd be more aggressive about taking early outs that prevent this sort of extra allocation. What I'm trying to illustrate here is that we can use the rules of probability (and the fact that we have pure predicates and projections) to produce distribution objects that give the correct results.
Aside: Of course, none of this fixes the original problems with weighted integer: that in the original constructor, we do not optimize away "all weights zero except one" or "trailing zeros". Those improvements are still left as exercises.
Next time on FAIC: we've seen that we can use Where to filter a distribution and make a conditioned distribution; we'll look at a richer and more complex way to represent conditional probabilities, and discover a not-so-surprising fact about our distribution interface.