This post is part of a series:
In this post, we are going to see how we need to adjust our decision tree algorithm from the previous post so that we can do a regression task with it.
Classification vs. Regression
To that end, let’s look again at the flow chart for our algorithm and then also at the scatter plot depicting the Iris flower data set.
See slide 1
By applying the algorithm to the data set, we found these decision boundaries:
See slide 2
And this whole process was a classification task because the label that we wanted to predict was the species of the flower. So, it was a categorical variable with 3 distinct classes, Iris-setosa, Iris-versicolor and Iris-virginica.
So now, to do a regression task, we first need a label that is actually a continuous variable.
See slide 3
And here, I simply created such a label, namely a fictional price that we could sell the flowers for. The price ranges between $0 and $10. And the darker the color of a dot, the higher the price that we could sell that flower for.
And, as you can see, there seem to be two distinct groups of flowers. The dots in the left cluster have a brighter color. So, they generally have a somewhat lower price. And the dots in the right cluster are darker and therefore their price is generally higher. So accordingly, to be able to make good predictions about new, unknown flowers, the decision boundary should ideally be between these two clusters, for example, at a petal width of 0.8.
See slide 4
So, let’s say that the algorithm has found this boundary as the best possible split. How do we now make predictions about new, unknown flowers that fall into one of the two areas? So, what price do we predict for a flower that falls into the left area? And what price do we predict for a flower that falls into the right area?
Using the Mean as the Prediction
So, in terms of our flow chart, we are thinking about how we would execute the “Classify” step. And since we are now distinguishing between classification and regression, “Classify” is probably not the best name for this step. So, let’s rename it to “Create leaf” because that seems to be the convention in the literature.
See slide 5
And just as a side note, I think it is called "leaf" because, along the analogy of a real tree, if you start at the root and move along the tree, then the last thing that you reach are the leaves.
See slide 6
So, I guess that’s why the parts of the decision tree, where we then make our prediction, are called "leaves".
Okay, so originally, for the classification task, we simply created a leaf based on which class appeared most often in a split area. This, however, is not possible anymore for a regression task, because all the flowers in a split area might have a different price. They, so to say, all belong to a different class. So, creating a leaf by checking which class appears most often doesn’t really make sense anymore.
And even if there are some flowers that happen to have the same price, that doesn't necessarily mean that they are representative for all the flowers in a split area. So, how do we then create a leaf when doing a regression task?
Well, a simple approach is to just calculate the average price of the flowers in a split area and use that as the prediction for new, unknown flowers that fall into that area.
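To make this concrete, here is a minimal sketch in Python (the function name and the example prices are just made up for illustration):

```python
import numpy as np

def create_leaf(prices):
    """For a regression task, a leaf simply predicts the mean
    price of all the flowers in the split area."""
    return float(np.mean(prices))

# Hypothetical prices of the flowers in one split area:
left_area_prices = [1.5, 2.0, 2.5, 3.0]
print(create_leaf(left_area_prices))  # 2.25
```

So, every new, unknown flower that falls into this area would get a predicted price of $2.25.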
So, that’s the first thing that we need to change about the algorithm.
See slide 7
Namely, how we create a leaf once we have found the best split. And this raises the question: How do we find that split in the first place?
Using Entropy for a Regression Task
Well, just like it was the case with the classification task, we are first going to determine all the potential splits.
See slide 8
So, we don’t need to change anything about the “Potential splits?”-step of the algorithm.
And then, in the next step, we need to be able to decide which of these splits is the best to split the data on. And for the classification task, the metric that we used to find the best split was called entropy. And I have a whole post where I explain how entropy works, but let me here just give you a quick recap.
See slide 9
Namely, what the entropy measures is, so to say, the impurity of the data in a split area. So, for a given number of classes, for example 2, the entropy is maximized when both classes have the same proportion or probability (0.5 in this case).
See slide 10
And this maximum value for the entropy, where each class has the same proportion, gets higher and higher with an increasing number of classes.
See slide 11
So, if you have 3 classes, then each class has a probability of one third and the maximum entropy is 1.58. If you have 4 classes, then each class has a probability of one fourth and the maximum entropy is 2.00 and so on.
So, if you wanted to summarize the behavior of the entropy, you could say: the more classes there are and the more evenly they are mixed up, the higher the entropy is going to be.
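To quickly recap this behavior in code, here is a minimal entropy sketch in Python (the class labels are made up and just serve as an illustration):

```python
import numpy as np

def entropy(labels):
    """Entropy of a list of class labels: -sum(p * log2(p))."""
    _, counts = np.unique(labels, return_counts=True)
    p = counts / counts.sum()
    return float(-np.sum(p * np.log2(p)))

# With evenly mixed classes, the entropy is maximized
# at log2(number of classes):
print(round(entropy(["a", "b"]), 2))            # 1.0
print(round(entropy(["a", "b", "c"]), 2))       # 1.58
print(round(entropy(["a", "b", "c", "d"]), 2))  # 2.0
```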
Okay, so keeping that in mind, let’s now think about what would happen if we used the entropy to find the best split for a regression task. To that end, let’s assume for now that all the flowers have a different price. So, each flower, so to say, belongs to a different class and each class therefore has the same proportion. So, we are always looking at maximum entropy values.
Okay, so to understand what happens when we use the entropy for a regression task, let’s only look at some edge cases of the splits. Namely, we are going to look at the outermost splits. And actually, we are going to look at the splits that are even one step further outside.
See slide 12
In terms of how to best split the data, they don’t really make sense because they split the data in such a way that all the dots are on one side of the split whereas on the respective other side of the split there are no dots. But considering those two splits will help us to understand what is going on if we use the entropy when doing a regression task.
And then, we are going to look at one more edge case. Namely, the split that cuts the data in half.
See slide 13
So, for this split, there is an equal number of dots on each side.
Okay so now, let’s start thinking about how the entropy behaves for these three splits. And we are going to start with the one on the far left.
See slide 14
Here, all of the flowers are on the right side. So, the entropy of the right side is going to be very high.
See slide 15
In this case, it is 7.04. And on the left side, there are no flowers, so the entropy is going to be zero.
See slide 16
Okay so then, since the entropy gives us only information about one side of the split, we calculate the overall entropy to get an understanding of the whole situation created by the split.
See slide 17
And the overall entropy is just a weighted average of the entropy values of the two split areas.
So, since all the dots are on the right side of this split, the weight for the right side is going to be one. And the weight of the left side is zero. So, the overall entropy for this split is going to be completely determined by the right side and it is also 7.04.
See slide 18
The same kind of reasoning applies to the split on the far right.
See slide 19
Only here, all the flowers are on the left side. So, the overall entropy is going to be 7.04 as well.
And now, let’s look at the split that cuts the data in half.
See slide 20
Here, there is an equal number of dots on both sides. Therefore, the entropy is the same on both sides.
See slide 21
In this case, it is 6.04. And since the overall entropy is simply a weighted average of those two entropy values, it is going to be 6.04 as well.
See slide 22
So, the overall entropy for this split is basically just based on the entropy of one side.
So now, let’s compare the three edge cases.
See slide 23
And here we can see that the overall entropy is lowest for the middle split. This necessarily must be the case, because for the outer splits, the overall entropy is based on the entropy that you get when you have to consider all the dots, whereas for the middle split, it is basically just based on the entropy that you get when you have to consider only half of the dots. Hence, it must be lower.
So now, let’s think about how the overall entropy behaves for our potential splits.
See slide 24
Namely, what we should see is that the overall entropy is highest at the outermost splits. And then, as we get closer and closer to the split that cuts the data in half, the overall entropy should decrease.
See slide 25
And that’s exactly the behavior that we can see.
So, for the special case where all the flowers have a different price, using the entropy for a regression task means that the best split turns out to be simply the split that cuts the data in half. So, it doesn’t help us to detect any patterns in the data.
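We can verify this behavior with a small sketch in Python. It assumes 132 flowers that all have a different price (the exact count is an assumption on my part, chosen because it reproduces the 7.04 and 6.04 values from above):

```python
import numpy as np

def entropy(labels):
    """Entropy of a list of labels: -sum(p * log2(p))."""
    if len(labels) == 0:
        return 0.0
    _, counts = np.unique(labels, return_counts=True)
    p = counts / counts.sum()
    return float(-np.sum(p * np.log2(p)))

def overall_entropy(left, right):
    """Weighted average of the entropies of the two split areas,
    weighted by the share of examples in each area."""
    n = len(left) + len(right)
    return len(left) / n * entropy(left) + len(right) / n * entropy(right)

# 132 flowers that all have a different price:
prices = list(range(132))

print(round(overall_entropy([], prices), 2))                # 7.04 (outer split)
print(round(overall_entropy(prices[:66], prices[66:]), 2))  # 6.04 (middle split)

# The split with the lowest overall entropy is always the 50/50 split:
values = [overall_entropy(prices[:k], prices[k:]) for k in range(133)]
print(values.index(min(values)))  # 66
```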
And even if there were some flowers that had the same price, this wouldn’t really change anything about this general behavior. We certainly wouldn’t find that the best split is actually the split that separates the two clusters of flowers.
See slide 26
So clearly, we can’t use entropy for a regression task, and we need a different metric to decide which of the splits is the best one. So, let’s now think about what that metric should actually measure. So, what it should tell us about the data.
Using MSE for a Regression Task
Namely, what we want is a metric that tells us something about how similar the flowers are to each other in a split area. And it shouldn’t just measure whether they are exactly the same, like entropy does.
And a metric that does exactly that is for example the mean squared error (MSE).
See slide 27
I also already have a post where I describe how that works. It was, however, made in the context of deep learning. So, the formula that we are going to use here is slightly different. But if you aren’t familiar with the MSE yet, you can nonetheless check out that post to understand how it works.
But basically, what you do in the formula is that you calculate the error. So, you subtract the prediction from the actual value. So, for example, you take the price of the flower at a petal width of 0.6. And from that price you subtract the average price of all the flowers on the left side of the split. And then, you simply square that error.
And this, you do for all the examples “i” (so for all the dots on the left side of the split). And then, you calculate the mean of all those squared errors. Hence, the name “mean squared error”.
And then, as before, we calculate the overall MSE which is again just the weighted average of the MSEs of both sides of the split.
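As a minimal sketch in Python (the prices are made up and just serve as an illustration):

```python
import numpy as np

def mse(prices):
    """Mean squared error of a split area: the prediction is the mean
    price, and the error is each actual price minus that prediction."""
    if len(prices) == 0:
        return 0.0
    prediction = np.mean(prices)
    return float(np.mean((np.asarray(prices) - prediction) ** 2))

def overall_mse(left, right):
    """Weighted average of the MSEs of the two split areas."""
    n = len(left) + len(right)
    return len(left) / n * mse(left) + len(right) / n * mse(right)

# A low-priced cluster on the left and a high-priced cluster on the right:
print(round(overall_mse([1.0, 2.0, 3.0], [7.0, 8.0, 9.0]), 2))  # 0.67
```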
See slide 28
So, that’s our new metric. And now, let’s think about how that behaves for different splits.
Namely, when we look at the split at a petal width of 0.8, then we can see that the flowers in the right area are relatively similar to each other in terms of their price. They all have a relatively high price. So, the average price of these flowers is going to be a good representation for those flowers in general and not one flower will be very far away from that average. So, the individual errors are going to be relatively low and therefore the MSE is also going to be relatively low.
And the same kind of reasoning applies to the left side, only here we have flowers with a generally low price. So, the MSE for the left side is going to be relatively low as well. And therefore, the overall MSE is also going to be low.
But now, let’s move the split more to the left.
See slide 29
With this split, one of the flowers with a generally lower price (the brighter dots) will be included in the right split area. Hence, the average price of all the flowers in the right split area is going to get somewhat lower. And because of that, all the flowers with a generally higher price (the darker dots) are now a little bit further away from the average. So, the errors are going to be a little bit bigger. And additionally, there is this one flower with a relatively low price that is very far away from the average. And therefore, the MSE on the right side is going to be somewhat higher compared to the split before.
On the left side, the MSE doesn’t really change because there are still only flowers with a generally low price. So, on the right side the MSE is going to increase and on the left side it is going to stay the same. Therefore, the overall MSE is going to be somewhat higher.
And if we had moved the split more to the right instead, we would have seen the same behavior.
See slide 30
Only here, the MSE on the left side would have increased and the MSE on the right side would have stayed the same.
So, as you can see, the MSE really tells us something about how similar the flowers in a split area are to each other. The lower it is, the more similar the flowers are to each other.
Therefore, we should find that the split at a petal width of 0.8 results in the lowest overall MSE.
See slide 31
And, as we can see, that is really the case.
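Putting it all together, finding the best split could be sketched like this in Python (the data is made up, but it mimics a low-priced cluster, a high-priced cluster, and a boundary at a petal width of 0.8):

```python
import numpy as np

def mse(prices):
    """Mean squared error of a split area around the mean price."""
    if len(prices) == 0:
        return 0.0
    return float(np.mean((prices - np.mean(prices)) ** 2))

def best_split(petal_widths, prices):
    """Try every potential split (the midpoints between neighboring
    feature values) and return the one with the lowest overall MSE."""
    order = np.argsort(petal_widths)
    xs, ys = np.asarray(petal_widths)[order], np.asarray(prices)[order]
    best_value, lowest = None, np.inf
    for i in range(1, len(xs)):
        if xs[i] == xs[i - 1]:
            continue  # no potential split between identical feature values
        left, right = ys[:i], ys[i:]
        overall = (len(left) / len(ys) * mse(left)
                   + len(right) / len(ys) * mse(right))
        if overall < lowest:
            best_value, lowest = (xs[i] + xs[i - 1]) / 2, overall
    return best_value, lowest

widths = [0.2, 0.4, 0.6, 1.0, 1.2, 1.4]
prices = [1.0, 2.0, 1.5, 8.0, 9.0, 8.5]
split_value, _ = best_split(widths, prices)
print(split_value)  # 0.8
```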
So, the second thing that we need to change about the algorithm is how we determine the best split.
See slide 32
When doing a classification task, we find the best split by calculating, for example, the overall entropy. And when doing a regression task, we find the best split by calculating, for example, the overall MSE.
So, let’s rename the “Lowest Overall Entropy?”-step to “Determine best split” so that it can represent both cases.
See slide 33
And while we are at it, let’s also rename the “Potential splits?”-step to “Determine potential splits”.
See slide 34
So, those are the only two differences between applying the decision tree algorithm to a classification task versus applying it to a regression task, namely how we create a leaf and how we determine the best split.
For the regression task, however, you have to keep in mind that the “Data pure?”-step would only say that the data is pure when there is actually just one flower left or when a couple of flowers happen to have the same price. In that case, however, the decision tree would have an extremely high number of layers and it would be overfitting.
So, to get around this problem we need another stopping condition (or in the context of recursion: base case) for the algorithm. And we can, for example, use the one that we have seen in the first post of this series where there needed to be a minimum number of examples.
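A minimal sketch of such a combined stopping check in Python (the name is_base_case and the default of 5 are just illustrative choices, not a fixed convention):

```python
def is_base_case(prices, min_samples=5):
    """Stop splitting when the data is pure (all prices are the same)
    or when there are fewer than min_samples examples left."""
    data_is_pure = len(set(prices)) <= 1
    too_few_examples = len(prices) < min_samples
    return data_is_pure or too_few_examples

print(is_base_case([4.2, 4.2, 4.2]))            # True (pure)
print(is_base_case([1.0, 2.0, 3.0]))            # True (fewer than 5 examples)
print(is_base_case([1.0, 2.0, 3.0, 4.0, 5.0]))  # False
```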
See slide 35
So now, since both of those stopping conditions are theoretically applicable to a classification or a regression task, let’s put them into one step. To that end, let’s rename the “Data pure?”-step to “Base cases?”.
See slide 36
So, this is how we need to adjust the decision tree algorithm from the previous post so that we are able to do a regression task with it. And if you want to see how to implement these changes in code, you can check out this post.