class: center, middle, inverse, title-slide

# Decision trees - Regression tree prediction and pruning

### Dr. D’Agostino McGowan

---

layout: true

<div class="my-footer">
<span>
Dr. Lucy D'Agostino McGowan <i>adapted from slides by Hastie & Tibshirani</i>
</span>
</div>

---

## Decision tree predictions

* Predict the response for a _test observation_ using the mean of the _training observations_ in the region that the _test observation_ belongs to

--

.question[
What could potentially go wrong with what we have described so far?
]

--

* The process may produce good predictions on the _training_ set but is likely to **overfit!**

---

## Pruning a tree

_Do you love the tree puns? I DO!_

* A smaller tree (with fewer splits, that is, fewer regions `\(R_1,\dots,R_J\)`) may lead to **lower variance** and better interpretation at the cost of a little **bias**

--

* A good strategy is to _grow_ a very large tree, `\(T_0\)`, and then **prune** it back to obtain a **subtree**

--

* For this, we use **cost complexity pruning** (also known as **weakest link** 🔗 **pruning**)

--

* Consider a sequence of trees indexed by a nonnegative tuning parameter, `\(\alpha\)`. For each `\(\alpha\)` there is a subtree `\(T \subset T_0\)` such that

`$$\sum_{m=1}^{|T|}\sum_{i:x_i\in R_m}(y_i-\hat{y}_{R_m})^2+\alpha|T|$$`

is as small as possible.

---

## Pruning

`$$\sum_{m=1}^{|T|}\sum_{i:x_i\in R_m}(y_i-\hat{y}_{R_m})^2+\alpha|T|$$`

* `\(|T|\)` indicates the number of terminal nodes of the tree `\(T\)`

--

* `\(R_m\)` is the box (the subset of the predictor space) corresponding to the `\(m\)`th terminal node

--

* `\(\hat{y}_{R_m}\)` is the mean of the training observations in `\(R_m\)`

---

## Choosing the best subtree

* The _tuning parameter_, `\(\alpha\)`, controls the trade-off between the subtree's _complexity_ and its _fit_ to the training data

--

.question[
How do you think you could select `\(\alpha\)`?
]

--

* You can select an optimal value, `\(\hat{\alpha}\)`, using **cross-validation**!

--

* Then return to the full dataset and obtain the subtree corresponding to `\(\hat{\alpha}\)`

---

## Summary: regression tree algorithm

* Use **recursive binary splitting** to grow a large tree on the training data, stopping only when you reach a stopping criterion (for example, when each terminal node has fewer than some minimum number of observations)

--

* Apply **cost complexity pruning** to the large tree to obtain a sequence of best subtrees, as a function of `\(\alpha\)`

--

* Use K-fold cross-validation to choose `\(\alpha\)`: pick the `\(\alpha\)` that minimizes the average cross-validated error

--

* Return the subtree that corresponds to the chosen `\(\alpha\)`

_A minimal sketch of these steps in R follows._

---
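## Sketch: growing a large tree

A minimal sketch of the growing step, assuming the `rpart` package and a hypothetical training data frame `train_df` with numeric response `y` (neither appears earlier in the deck). Setting `cp = 0` disables `rpart`'s built-in cost-complexity stopping rule, so the tree grows large, as the algorithm prescribes.

```r
library(rpart)

# Grow a large tree T_0 by recursive binary splitting;
# cp = 0 applies no cost-complexity penalty while growing,
# xval = 10 requests 10-fold cross-validation for pruning later
big_tree <- rpart(
  y ~ .,                                     # hypothetical response and predictors
  data    = train_df,                        # hypothetical training data frame
  method  = "anova",                         # regression tree (RSS-based splits)
  control = rpart.control(cp = 0, xval = 10)
)
```

---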
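## Sketch: choosing `\(\alpha\)` and pruning

Continuing the sketch from the previous slide: `rpart` stores the sequence of subtrees indexed by its complexity parameter `cp` (a rescaled version of `\(\alpha\)`), together with the cross-validated error of each. We pick the value with the smallest cross-validated error and prune back to the corresponding subtree.

```r
# Inspect the cross-validated error (xerror) for each candidate subtree
printcp(big_tree)

# Pick the cp whose cross-validated error is smallest
cp_table <- big_tree$cptable
best_cp  <- cp_table[which.min(cp_table[, "xerror"]), "CP"]

# Prune the large tree back to the chosen subtree
pruned_tree <- prune(big_tree, cp = best_cp)
```

---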
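## Sketch: predicting for test observations

Still continuing the sketch, with a hypothetical test data frame `test_df` holding the same predictors: for each test observation, `predict()` returns the mean of the training responses in the terminal-node region the observation falls into, which is exactly the prediction rule from the start of the deck.

```r
# Each prediction is the mean training response of the region
# (terminal node) that the test observation belongs to
y_hat <- predict(pruned_tree, newdata = test_df)
head(y_hat)
```

---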