This post is part of a series:
In this post, we are going to cover how decision tree pruning works. So, first of all:
Why do we need to prune decision trees?
We need to prune decision trees because they tend to overfit the training data. To understand why that is, let’s look at a flow diagram of a basic decision tree algorithm (which we have derived in the previous three posts).
See slide 1
So, first we check if the data is pure. If it is, then we create a leaf and stop. If it isn’t, then we determine the potential splits. After that, we determine the best one of those splits and then we split the data accordingly. And finally, we repeat this process for both partitions of the data. And we keep doing that until all partitions are eventually pure and we therefore reach the “Stop”-field for all recursive calls of the decision tree algorithm.
And that’s exactly where the problem lies. Namely, the algorithm only comes to a stop when each partition of the data is pure. So, in case the different classes aren’t clearly separated from each other or in case there are outliers, the resulting decision tree will simply have too many layers, i.e. it will overfit the training data. And because of that, its ability to generalize well to new, unseen data will be diminished.
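Just to make this flow concrete, here is a rough sketch in Python. The data format ((x, y, label) tuples), the Gini-impurity split and all helper names (`is_pure`, `best_split`, etc.) are assumptions made for illustration, not taken from any particular library:

```python
from collections import Counter

def is_pure(data):
    # Base case: all data points belong to the same class.
    return len({label for *_, label in data}) == 1

def gini(labels):
    # Gini impurity of a list of class labels.
    n = len(labels)
    if n == 0:
        return 0.0
    return 1.0 - sum((c / n) ** 2 for c in Counter(labels).values())

def best_split(data):
    # Determine the potential splits (midpoints between neighboring feature
    # values) and return the one with the lowest weighted Gini impurity.
    best = None
    for feat in (0, 1):  # feature index: 0 = x, 1 = y
        values = sorted({p[feat] for p in data})
        for lo, hi in zip(values, values[1:]):
            thresh = (lo + hi) / 2
            left = [p[-1] for p in data if p[feat] <= thresh]
            right = [p[-1] for p in data if p[feat] > thresh]
            score = (len(left) * gini(left) + len(right) * gini(right)) / len(data)
            if best is None or score < best[0]:
                best = (score, feat, thresh)
    return best[1], best[2]

def build_tree(data):
    # The algorithm only stops once a partition is pure -- which is exactly
    # what makes it overfit noisy data.
    if is_pure(data):
        return data[0][-1]  # leaf: the single remaining class
    feat, thresh = best_split(data)
    left = [p for p in data if p[feat] <= thresh]
    right = [p for p in data if p[feat] > thresh]
    return (feat, thresh, build_tree(left), build_tree(right))

def predict(tree, point):
    # Walk the tree until we reach a leaf (a plain class label).
    while isinstance(tree, tuple):
        feat, thresh, left, right = tree
        tree = left if point[feat] <= thresh else right
    return tree
```

Note that `build_tree` recurses until every partition is pure, so a single outlier is enough to force extra splits and a deeper tree.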
To see this problem in action, let’s look at some example data.
See slide 2
So here, there are two classes, namely there are “Orange” dots and “Blue” dots. So, we are going to do a classification task. And those two classes are clearly separated from each other. If x is smaller than 5, then the respective data point is “Orange”. Otherwise, it is “Blue”. So, if we apply our algorithm to this data, then we get a decision tree that looks like this:
See slide 3
And if we apply this decision tree to the testing data, then we can see that it is very good at predicting the different classes which means that it can generalize well to new, unseen data.
See slide 4
So, for the case where the different classes are clearly separated from each other (which is rarely the case in practice), we certainly can create a decision tree with our algorithm that is not overfitting. But now, let’s add one “Orange” outlier to the training data (x=5.4; y=8.4).
See slide 5
In this case, our simple decision tree is still ideal since it captures the general distinction between the two classes. There just happens to be that one outlier in the training data which has a slightly larger x than what is usual for data points from the “Orange” class.
But because of this one outlier, the right partition of the data is still not pure (for the data points with an x bigger than 5, there are mostly “Blue” dots but also one “Orange” dot). So, the algorithm doesn’t come to a stop yet and it will keep splitting the data. Or in other words, it will create a deeper tree with more layers to separate out the outlier. And it could look something like this:
See slide 6
And looking at the testing data, we can see that this tree misclassifies some of the “Blue” data points as belonging to class “Orange” (the three blue dots in the small orange area). So, the decision tree is now overfitting the training data and it is not as good as it was before at generalizing to new, unseen data (and you can probably imagine that the situation gets even worse if there were some more outliers).
To get around this problem, we need to prune the tree, i.e. we need to make sure that the tree doesn’t have too many layers. And there are two different types of pruning.
Pre-pruning: Min-Samples Approach
The first one is called pre-pruning. And here, we simply make sure that the tree doesn’t get too deep in the first place.
One way of doing that, for example, is to specify a minimum number of data points that need to be present in the data.
See slide 7
If the number of data points is less than this minimum number, then we create a leaf even though the data isn’t pure yet. In that case, the leaf predicts the class that appears most often. Or, if the different classes appear equally often, we simply pick one of them at random.
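In code, this base case might look something like the following sketch (again assuming (x, y, label) tuples; the helper names are made up for illustration):

```python
from collections import Counter

def make_leaf(data):
    # The leaf predicts the most frequent class in the partition. For exact
    # ties, most_common() picks one class arbitrarily, which stands in for
    # the "pick one at random" rule described above.
    return Counter(label for *_, label in data).most_common(1)[0][0]

def too_few_samples(data, min_samples=5):
    # If fewer than min_samples points are left, we stop and create a leaf
    # even though the partition may not be pure yet.
    return len(data) < min_samples
```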
So, in our case, if we set the minimum number of samples to 5, then we wouldn’t ask the last question in the tree (“y <= 8.4”). And that’s because at that point, there are only 4 data points left in the respective partition of the training data.
See slide 8
So, accordingly, we create a leaf instead of splitting the data again. And since there are 3 “Blue” dots and only one “Orange” dot, the leaf is going to be “Blue”.
See slide 9
And, as you can see, for data points with an x bigger than 5, the tree now always predicts “Blue”. So, essentially, this tree makes the same predictions as the tree that we got earlier when there was no outlier.
See slide 10
So, the tree is not overfitting anymore (despite the outlier). And it predicts all the test data points correctly again which means it generalizes better to new, unseen data.
Pre-pruning: Max-Depth Approach
So, that’s one way of implementing pre-pruning. Another way is to specify a maximum depth for the tree, i.e. a maximum number of layers that the tree should have.
See slide 11
In that case, we initialize a “Counter” at the start of the algorithm to keep track of how many layers have already been created. And this “Counter” gets increased by 1 every time the data gets split. Then, if the maximum depth is reached at some point (i.e. Counter == Max-Depth), we create a leaf, even if the data isn’t pure yet.
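As a sketch, the counter could be threaded through the recursion like this (simplified to a single feature with a midpoint split; all names are illustrative):

```python
from collections import Counter

def build(data, counter=0, max_depth=3):
    # data: list of (value, label) pairs with a single numeric feature.
    labels = [label for _, label in data]
    # Base cases: the partition is pure, or Counter == Max-Depth.
    if len(set(labels)) == 1 or counter == max_depth:
        return Counter(labels).most_common(1)[0][0]  # majority-class leaf
    # Simplified split: the midpoint of the value range.
    mid = (min(v for v, _ in data) + max(v for v, _ in data)) / 2
    left = [p for p in data if p[0] <= mid]
    right = [p for p in data if p[0] > mid]
    # Every split adds one layer, so the counter increases by 1 per recursion.
    return (build(left, counter + 1, max_depth),
            build(right, counter + 1, max_depth))
```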
So, in our case, the basic decision tree algorithm without pre-pruning created a tree with 4 layers. Therefore, if we set the maximum depth to 3, then the last question (“y <= 8.4”) won’t be included in the tree. So, after the decision node “y <= 7.5”, the algorithm is going to create leaves. And to see what those leaves will be, let’s only depict the respective partitions of the data.
See slide 12
So, the data points with a y smaller than or equal to 7.5 are all “Blue”. Therefore, the algorithm creates a “Blue” leaf. And above a y of 7.5, just as before, there are 3 “Blue” dots and 1 “Orange” dot. So, the leaf will be “Blue” as well.
See slide 13
So, we get the same tree as with the “Min-Samples” approach. And it again predicts all the test data points correctly.
And with that, we have now seen two ways of pre-pruning a decision tree. So now, for the sake of completeness, let’s include all the different base cases for the algorithm into our flow chart (without cluttering it up too much). To do that, let’s rewrite it like this:
See slide 14
So, if one of the base cases applies, i.e. the data is pure, the maximum depth is reached or the minimum number of data points is not reached, then we create a leaf. Otherwise, we split the data.
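To make the combined flow chart concrete, here is a rough, self-contained sketch in Python. The data format ((x, y, label) tuples), the Gini-based split and all helper names are assumptions for illustration, not part of any specific library:

```python
from collections import Counter

def gini(labels):
    # Gini impurity of a list of class labels.
    n = len(labels)
    if n == 0:
        return 0.0
    return 1.0 - sum((c / n) ** 2 for c in Counter(labels).values())

def best_split(data):
    # Pick the (feature, threshold) pair with the lowest weighted impurity.
    # Assumes the data points don't all share the same feature values.
    best = None
    for feat in (0, 1):  # feature index: 0 = x, 1 = y
        values = sorted({p[feat] for p in data})
        for lo, hi in zip(values, values[1:]):
            thresh = (lo + hi) / 2
            left = [p[-1] for p in data if p[feat] <= thresh]
            right = [p[-1] for p in data if p[feat] > thresh]
            score = (len(left) * gini(left) + len(right) * gini(right)) / len(data)
            if best is None or score < best[0]:
                best = (score, feat, thresh)
    return best[1], best[2]

def build_tree(data, depth=0, max_depth=3, min_samples=5):
    labels = [p[-1] for p in data]
    # All three base cases from the flow chart: pure data, maximum depth
    # reached, or too few data points left in the partition.
    if len(set(labels)) == 1 or depth == max_depth or len(data) < min_samples:
        return Counter(labels).most_common(1)[0][0]  # majority-class leaf
    feat, thresh = best_split(data)
    left = [p for p in data if p[feat] <= thresh]
    right = [p for p in data if p[feat] > thresh]
    return (feat, thresh,
            build_tree(left, depth + 1, max_depth, min_samples),
            build_tree(right, depth + 1, max_depth, min_samples))

def predict(tree, point):
    # Walk the tree until we reach a leaf (a plain class label).
    while isinstance(tree, tuple):
        feat, thresh, left, right = tree
        tree = left if point[feat] <= thresh else right
    return tree
```

With these three base cases in place, the recursion is guaranteed to stop even if a partition never becomes pure.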
Okay, so that’s the first type of pruning. The other type is called post-pruning. And here, in contrast to pre-pruning, we let the tree go deep. And then, afterwards, we prune it back. So, how does that work?
First of all, we need a training data set and a validation data set.
See slide 15
Then, as usual, we create a tree based on the training data.
See slide 16
And here, we don’t need to worry that it gets too deep, i.e. we create the tree without doing any pre-pruning (Side note: Technically, we could also do some pre-pruning. But in this simple scenario it is not necessary).
And now, we prune this tree. To do that, we start with the decision node at the deepest layer (if there are several decision nodes at the deepest layer, we start with the leftmost one).
See slide 17
And what we now want to know is: Should we keep this decision node, or should we instead already create a leaf at this stage in the tree?
To answer this question, we obviously first need to know what the leaf should be. Or in other words, we need to know what leaf the decision tree algorithm would have created if it hadn’t created the decision node.
This is pretty straightforward to do. Namely, we just filter our training data based on the questions that we need to ask to get to this stage in the decision tree. So, we are only looking at data points with an x that is bigger than 5.0 but smaller than or equal to 5.4 and a y that is bigger than 7.5.
See slide 18
And then, we create the leaf based on which class appears most often (or, in case of a regression task, what the average value is). So, again, since there are 3 “Blue” dots and only 1 “Orange” dot, the leaf should be “Blue”.
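As a sketch, this filtering step might look like the following, with hypothetical training points chosen so that, as in the example, 3 “Blue” dots and 1 “Orange” dot end up in the node’s partition:

```python
from collections import Counter

# Hypothetical (x, y, label) training points, made up for illustration.
train = [
    (4.0, 8.0, "Orange"),   # filtered out: x <= 5
    (5.2, 6.0, "Blue"),     # filtered out: y <= 7.5
    (5.2, 7.8, "Blue"),
    (5.3, 8.1, "Blue"),
    (5.1, 8.9, "Blue"),
    (5.4, 8.4, "Orange"),   # the outlier
]

# Only keep the data points that reach this stage of the tree; the
# thresholds come from the questions on the path to the node.
partition = [p for p in train if 5.0 < p[0] <= 5.4 and p[1] > 7.5]

# The leaf predicts the most common class (or, for regression, the mean).
leaf = Counter(label for *_, label in partition).most_common(1)[0][0]
```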
See slide 19
Okay, so in order to answer our question of whether we should keep the decision node or create a leaf instead, we now need the validation data set. And what we want to know is: Is the decision node better at predicting the respective data points, or is the leaf better?
So, just like we did with the training data set, let’s filter the validation data set based on the questions in the tree.
See slide 20
And now, we can check how many errors the decision node is making and how many errors the leaf is making.
See slide 21
So, as you can see, the decision node predicts one “Blue” dot correctly, but it misclassifies three “Blue” dots. The leaf, on the other hand, predicts all data points correctly as being “Blue”.
So, since the leaf is better at making predictions than the decision node, we replace the decision node with the leaf.
See slide 22
And this procedure is called “Reduced Error Pruning”. So, if the number of errors for the leaf is smaller than or equal to the number of errors of the decision node, we replace the decision node with the leaf. Otherwise, we keep the decision node.
Side note: If we are doing a regression task, then we check if the mean squared error of the leaf is less than or equal to the mean squared error of the decision node.
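As a sketch, the pruning rule for both cases might look like this (the function names are made up for illustration):

```python
def n_errors(prediction, labels):
    # Classification: how many validation points does this prediction miss?
    return sum(1 for label in labels if label != prediction)

def mse(prediction, values):
    # Regression: mean squared error of predicting a constant value.
    return sum((v - prediction) ** 2 for v in values) / len(values)

def replace_node_with_leaf(leaf_error, node_error):
    # Reduced Error Pruning rule: prune if the leaf does at least as well
    # as the decision node on the validation data.
    return leaf_error <= node_error
```

With the numbers from the example (the leaf makes 0 errors on the four validation points, the decision node makes 3), the rule says to replace the node with the leaf.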
So now, we simply repeat this process for all the decision nodes in the tree.
See slides 23-28
And, as you can see, after pruning the tree, it again looks like the tree that was originally created when there was no outlier in the training data (see slide 4). So, again, it is not overfitting anymore and predicts all data points of the testing data set correctly.
See slide 29
So, this is how post-pruning works. And if you want to see how to implement this in code, you can check out this post.