What Counting Jelly Beans Can Teach Us About Machine Learning

Remember that old carnival game, the one where you attempt to guess the number of jelly beans in a jar? While it often took some combination of luck and skill for any single person to accurately guess the correct number, it turns out that by averaging all of the guesses of a wide variety of people together, the averaged answer is surprisingly close to the correct response.

This phenomenon is an example of what’s known as “the wisdom of the crowd,” a strategy frequently used in machine learning. It needs two ingredients. First, you need a large and diverse set of perspectives, each of which carries some measure of signal but is not correlated with any other perspective (so errors tend to be symmetrically distributed around the truth). Second, you need a suitable way of aggregating those perspectives, such as averaging. With both in place, the “rightness stacks up” in the aggregate while the errors tend to cancel each other out.

In the case of the jelly bean example, this means you must have a lot of people submit guesses (large number of perspectives), they’re all looking at the same jar of jelly beans (must have some measure of signal), and those people can’t talk to each other about their guesses (perspectives are not otherwise correlated). It’s fine if some people just take a quick glance at the jar, while others study it for a long time, or if someone submits a wildly incorrect guess—either over or under—as a joke.

The thing is, all of the guesses are centered on the same truth. Each guess is partly right and partly wrong; when the guesses are combined, the right parts stack up and the wrong parts cancel out. Because any given guess is about as likely to be an underestimate as an overestimate, even the large misses tend to cancel each other as more and more guesses are averaged over the whole population.
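To make the cancellation concrete, here is a minimal simulation sketch. It assumes the idealized conditions described above: every guess is the true count plus independent, symmetric noise. The jar size, the number of guessers, and the spread of the guesses are made-up values chosen purely for illustration.

```python
import random
import statistics

def simulate_crowd(true_count, n_guessers, spread, seed=0):
    """Simulate independent guesses and compare the crowd to individuals.

    Assumption: each guess is the true count plus symmetric Gaussian noise,
    so overestimates and underestimates are equally likely.
    Returns (crowd_average, mean_individual_error).
    """
    rng = random.Random(seed)
    guesses = [rng.gauss(true_count, spread) for _ in range(n_guessers)]
    crowd_average = statistics.fmean(guesses)
    individual_errors = [abs(g - true_count) for g in guesses]
    return crowd_average, statistics.fmean(individual_errors)

# Hypothetical jar of 1,000 jelly beans and 2,000 fairly noisy guessers.
crowd_average, mean_individual_error = simulate_crowd(
    true_count=1000, n_guessers=2000, spread=300
)
print(f"error of the averaged guess: {abs(crowd_average - 1000):.1f}")
print(f"average error of one person: {mean_individual_error:.1f}")
```

Under these assumptions the averaged guess should land far closer to the true count than a typical individual does. If the noise were skewed or shared between guessers, the cancellation would be weaker.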

The wisdom of the crowd is the backbone of random forest modeling, one of the most popular machine learning algorithms, if not the most popular. A random forest works by training hundreds of “weaker” machine learning models, called decision trees.

[Figure: decision tree]

A decision tree will run through the available features of the data it is given, using the patterns it sees to establish a series of yes-or-no questions strategically designed to sort the data points according to a particular question (e.g., is this kind of customer likely to renew a contract?). In the case of the jelly beans, this could be thought of as analogous to a single person reasoning with questions such as “how large is the jar?” and “how large are the jelly beans?” But each of these decision trees is hampered: it is only given access to a subset of the total available features when designing its series of yes-or-no questions. Moreover, because a tree considers the yes-or-no questions one at a time (in series, not in combination) as it carves out an answer, its ability to capture subtle interactions can suffer.
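As a rough sketch of how a tree might pick one of those yes-or-no questions, the snippet below searches for the single question of the form "is feature f at most t?" that misclassifies the fewest rows. Everything here is an illustrative assumption: the toy customer data, the feature meanings, and the error-counting score (real libraries typically use criteria such as Gini impurity or entropy). A full tree would repeat this search on each resulting group.

```python
def majority_label(labels):
    """Return the most common label in a non-empty list."""
    return max(set(labels), key=labels.count)

def best_question(rows, labels):
    """Find the yes/no question "row[feature] <= threshold?" with the
    fewest misclassified rows when each side predicts its majority label.

    Returns (errors, feature, threshold, label_if_yes, label_if_no).
    """
    best = None
    for feature in range(len(rows[0])):
        for threshold in sorted({row[feature] for row in rows}):
            yes = [lab for row, lab in zip(rows, labels) if row[feature] <= threshold]
            no = [lab for row, lab in zip(rows, labels) if row[feature] > threshold]
            errors = 0
            yes_label = no_label = None
            if yes:
                yes_label = majority_label(yes)
                errors += sum(lab != yes_label for lab in yes)
            if no:
                no_label = majority_label(no)
                errors += sum(lab != no_label for lab in no)
            # Keep the first question that achieves the lowest error count.
            if best is None or errors < best[0]:
                best = (errors, feature, threshold, yes_label, no_label)
    return best

# Hypothetical data: (months as a customer, support tickets filed).
rows = [(2, 5), (3, 4), (4, 6), (12, 1), (18, 0), (24, 2)]
labels = [False, False, False, True, True, True]  # did the customer renew?

errors, feature, threshold, if_yes, if_no = best_question(rows, labels)
print(f"ask: is feature {feature} <= {threshold}? yes -> {if_yes}, no -> {if_no}")
```

On this toy data a single question about customer tenure separates the two groups, but real data rarely splits that cleanly, which is why trees chain many questions together.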

Random forest, on the other hand, uses the “wisdom of the crowd” of a collection of decision trees. Each tree takes a different perspective on the data (just like the individual carnival attendees) and comes to a decision, and then those decisions are aggregated, by averaging or by voting. Because each tree is only looking at some of the features, no individual tree is going to be very good at predicting on its own the right answer for every data point. Many of the trees will primarily be looking at unimportant features, the so-called “noise” in the data.

However, in a random forest we don’t ask just one tree for the answer; we ask an ensemble of hundreds of trees. The key here is that the trees that pay attention mostly to noise (irrelevant information) will tend to be right about as often as they are wrong, so their votes cancel each other out, similar to what happens to errors in the jelly bean example. The trees that pay a little better attention to features that are more predictive of the right answer will be right slightly more often than they are wrong. So if we create a forest with enough trees, the “bad trees” end up in a roughly 50-50 split, and the votes of the “better trees” get to break the tie.
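The tie-breaking effect can be checked with a little arithmetic. The sketch below computes the exact probability that a majority vote is right when every voter is right with the same probability and all votes are independent. Both of those are simplifying assumptions: the trees in a real forest differ in quality and are never fully independent, so treat the numbers as an idealized upper bound on the effect.

```python
from math import comb

def majority_correct_probability(n_voters, p_correct):
    """Probability that a strict majority of independent voters is right.

    Assumptions: n_voters is odd (no ties), each voter is right with
    probability p_correct, and votes are independent of one another.
    """
    needed = n_voters // 2 + 1
    return sum(
        comb(n_voters, k) * p_correct**k * (1 - p_correct) ** (n_voters - k)
        for k in range(needed, n_voters + 1)
    )

# Voters that are only slightly better than a coin flip (55% accurate).
for n_voters in (1, 11, 101, 501):
    probability = majority_correct_probability(n_voters, 0.55)
    print(f"{n_voters:>3} voters -> majority right with probability {probability:.3f}")
```

With a 50% accuracy the majority stays at a coin flip no matter how many voters are added, which matches the "bad trees" cancelling out. Any consistent edge above 50% compounds as the crowd grows.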

[Figure: random forest]

As in the jelly bean example, by polling a lot of guesses (even if none of them is especially accurate for all cases) and combining these guesses, we can get results that are remarkably accurate the vast majority of the time.

Of course, in order for this to work, a random forest model must also meet some conditions. First off, the decision trees must be different enough from each other that they don’t all produce the same response. They must ask different questions and look at different pieces of the dataset. If every tree looked at 100% of the same features, the trees would all vote the same way on each data point and the forest would have no advantage over a single decision tree. In that case, the forest would absorb too much of the noise in the data, and when you move to a new dataset, you’d see a huge gap in performance. (This is what we refer to as overfitting: building a model in a way that pays too much attention to the noise of a particular dataset.)
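One way to see why the trees need to differ is to look at how much noise survives averaging. The sketch below uses the standard formula for the variance of an average of equally noisy predictions that all share the same pairwise correlation. That equal-variance, equal-correlation setup is an idealization chosen for illustration, not a description of any particular forest.

```python
def variance_of_average(n_models, single_variance, correlation):
    """Variance of the average of n equally noisy predictions.

    Assumption: every prediction has the same variance and every pair of
    predictions has the same correlation (between 0 and 1).
    """
    return single_variance * (correlation + (1 - correlation) / n_models)

# 500 models, each with error variance 1.0, at different correlation levels.
for correlation in (0.0, 0.3, 0.9, 1.0):
    remaining = variance_of_average(500, 1.0, correlation)
    print(f"correlation {correlation:.1f} -> variance after averaging {remaining:.3f}")
```

With fully correlated models, averaging removes nothing, which is the "no advantage over a single decision tree" case. The less the models agree in their mistakes, the more of the noise averages away.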

The opposite problem is building decision trees that are too small. Trees that pay attention to too few of the features don’t get a good enough picture of the data. None of the trees will have enough information to reliably capture the signal, even “weakly,” so there won’t be enough tie-breakers to make the ensemble reliable.

The trick, of course, is finding a good balance between overfitting and underfitting. This is done by tuning the “hyperparameters” that shape what kind of model you are building (like the regularization parameter in logistic regression, or the width of the Gaussian in an RBF-kernel SVM). In the case of random forest, the most important hyperparameter is m, which specifies how many of the p total features each tree gets to look at.

One of the reasons why random forest is so popular is that there are some very reliable rules of thumb to guide the choice of m depending on the kind of data one is working with, making random forest one of the easiest algorithms to get working well.
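The post does not spell out which rules of thumb it has in mind. As an assumption, the sketch below uses two commonly cited defaults: roughly the square root of p for classification and roughly p/3 for regression. The function names and feature names are hypothetical, and note that many implementations draw a fresh random subset of m features at every split rather than once per tree.

```python
import math
import random

def rule_of_thumb_m(p, task):
    """Suggest how many of the p features to consider.

    Assumption: the common defaults of floor(sqrt(p)) for classification
    and floor(p / 3) for regression, never going below one feature.
    """
    if task == "classification":
        return max(1, math.isqrt(p))
    if task == "regression":
        return max(1, p // 3)
    raise ValueError(f"unknown task: {task!r}")

def sample_features(feature_names, m, rng):
    """Pick m distinct features at random, as one tree (or split) might."""
    return rng.sample(feature_names, m)

# Hypothetical dataset with 16 features for a classification problem.
feature_names = [f"feature_{i}" for i in range(16)]
m = rule_of_thumb_m(len(feature_names), "classification")
rng = random.Random(0)
for tree_index in range(3):
    print(tree_index, sample_features(feature_names, m, rng))
```

These defaults are starting points rather than guarantees; trying a few values of m around them and comparing performance on held-out data is the usual way to settle on one.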

This blog post was written with contributions from Mike Tamir.
