Decision trees - Regression tree example

Dr. D’Agostino McGowan

1 / 13

The baseball example

2 / 13

1. Randomly divide the data in half: 132 training observations and 131 testing observations

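library(tidymodels)  # attaches rsample, parsnip, tune, yardstick, dplyr, and tidyr

set.seed(77)  # make the random split reproducible
baseball_split <- initial_split(baseball, prop = 0.5)
baseball_train <- training(baseball_split)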
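
For intuition, the random half split can be sketched in base R. This is only an illustration of the idea, not what `initial_split()` does internally. The total of 263 rows is inferred from 132 + 131, the names `train_idx` and `test_idx` are hypothetical, and the rows selected will not match the rsample split even with the same seed.

```r
set.seed(77)
n <- 132 + 131                              # total rows implied by the slide
train_idx <- sample(n, size = 132)          # row numbers for the training half
test_idx <- setdiff(seq_len(n), train_idx)  # every remaining row is testing
length(train_idx)
length(test_idx)
```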
3 / 13

2. Create a cross-validation object for 6-fold cross-validation

baseball_cv <- vfold_cv(baseball_train, v = 6)
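
As a rough picture of what the 6 folds look like, the base R sketch below assigns each of the 132 training rows to one of 6 folds at random. It is an assumption-laden illustration rather than the rsample implementation; `fold_id`, `analysis_rows`, and `assessment_rows` are hypothetical names.

```r
set.seed(77)
n_train <- 132
v <- 6

# Shuffle a balanced vector of fold labels: 132 / 6 = 22 rows per fold
fold_id <- sample(rep(seq_len(v), length.out = n_train))
table(fold_id)

# For fold 1, the model is fit on the other five folds and assessed on fold 1
analysis_rows <- which(fold_id != 1)
assessment_rows <- which(fold_id == 1)
```

Each of the 6 resamples repeats this with a different fold held out, so every training row is used for assessment exactly once.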
4 / 13

3. Create a model specification that tunes based on complexity, α

tree_spec <- decision_tree(
  cost_complexity = tune(),
  tree_depth = 10,
  mode = "regression") %>%
  set_engine("rpart")
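
The α being tuned is the penalty in cost-complexity pruning. As a reminder, in the usual textbook statement the subtree $T$ is chosen to minimize the quantity below, where $|T|$ is the number of terminal nodes, $R_m$ is the region for the $m$th terminal node, and $\hat{y}_{R_m}$ is the mean training response in that region. The notation is assumed from the standard formulation, not taken from these slides. Note also that `cost_complexity` is passed to rpart's `cp` argument, which is on a rescaled, relative scale, so its numeric value is not the raw α below.

```latex
\sum_{m=1}^{|T|} \sum_{i:\, x_i \in R_m} \left( y_i - \hat{y}_{R_m} \right)^2 + \alpha |T|
```

A larger α charges more for each terminal node, so it favors a smaller tree.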

What is my tree depth for my "large" tree?

5 / 13

4. Fit the model on the cross-validation folds

grid <- expand_grid(cost_complexity = seq(0.01, 0.05, by = 0.01))
model <- tune_grid(tree_spec,
                   Salary ~ Hits + Years + PutOuts + RBI + Walks + Runs,
                   grid = grid,
                   resamples = baseball_cv)

Which values of α am I trying?
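
One way to check is to print the sequence directly. The short base R sketch below does that and also counts how many models the tuning step fits, assuming one fit per candidate value per fold; `alphas` and `n_folds` are hypothetical names.

```r
alphas <- seq(0.01, 0.05, by = 0.01)  # the candidate cost_complexity values
alphas

n_folds <- 6                          # from the 6-fold cross-validation object
length(alphas) * n_folds              # model fits performed during tuning
```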

6 / 13

5. Choose α that minimizes the RMSE

model %>%
  collect_metrics() %>%
  filter(.metric == "rmse") %>%
  arrange(mean)
## # A tibble: 5 x 6
##   cost_complexity .metric .estimator  mean     n std_err
##             <dbl> <chr>   <chr>      <dbl> <int>   <dbl>
## 1            0.03 rmse    standard    391.     6    38.5
## 2            0.05 rmse    standard    399.     6    38.8
## 3            0.01 rmse    standard    399.     6    34.9
## 4            0.02 rmse    standard    402.     6    36.2
## 5            0.04 rmse    standard    404.     6    36.6
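
Each row of this output summarizes six per-fold RMSE values: `mean` is their average and `std_err` is their standard deviation divided by the square root of `n`. A minimal base R sketch of that summary follows, using made-up per-fold values (hypothetical numbers, not the folds behind the table).

```r
# Hypothetical RMSE on each of the 6 assessment folds for one candidate alpha
fold_rmse <- c(300, 450, 380, 410, 395, 405)
n <- length(fold_rmse)

cv_summary <- data.frame(
  mean = mean(fold_rmse),             # what collect_metrics() reports as `mean`
  n = n,                              # number of folds
  std_err = sd(fold_rmse) / sqrt(n)   # standard error of the mean across folds
)
cv_summary
```

In the output above, the candidate means differ by about 13 at most while each standard error is roughly 35 to 39, so the ranking among these α values is not strongly determined.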
7 / 13

model %>%
  select_best(metric = "rmse")
## # A tibble: 1 x 1
##   cost_complexity
##             <dbl>
## 1            0.03
8 / 13

final_complexity <- model %>%
  select_best(metric = "rmse") %>%
  pull(cost_complexity)  # name the column so the result does not depend on column order
9 / 13

6. Fit the final model

final_spec <- decision_tree(
  cost_complexity = final_complexity,
  tree_depth = 10,
  mode = "regression") %>%
  set_engine("rpart")
final_model <- fit(final_spec,
                   Salary ~ Hits + Years + PutOuts + RBI + Walks + Runs,
                   data = baseball_train)
10 / 13

Final tree

How many terminal nodes does this tree have?
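
One way to check the answer in code is to count the leaf rows of the fitted rpart object. The sketch below uses the built-in `mtcars` data as a stand-in because the baseball data is not included here; `toy_tree` and `n_terminal` are hypothetical names. For the model fit in step 6, the underlying rpart object should be available as `final_model$fit`, and the same counting expression would apply to it.

```r
library(rpart)

# Stand-in regression tree on a built-in dataset (not the baseball data)
toy_tree <- rpart(mpg ~ ., data = mtcars, method = "anova")

# In an rpart fit, terminal nodes are the rows of $frame marked "<leaf>"
n_terminal <- sum(toy_tree$frame$var == "<leaf>")
n_terminal
```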

11 / 13

Calculate RMSE on the test data

baseball_test <- testing(baseball_split)
final_model %>%
  predict(new_data = baseball_test) %>%
  bind_cols(baseball_test) %>%
  metrics(truth = Salary, estimate = .pred)
## # A tibble: 3 x 3
##   .metric .estimator .estimate
##   <chr>   <chr>          <dbl>
## 1 rmse    standard     363.
## 2 rsq     standard       0.356
## 3 mae     standard     267.
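
To make the three reported metrics concrete, here is a minimal base R sketch that computes them by hand. The `truth` and `estimate` vectors are made-up values (hypothetical, not the baseball test set), and the R-squared here is the squared correlation between truth and prediction, which is how yardstick defines `rsq`.

```r
# Hypothetical observed salaries and model predictions
truth <- c(500, 750, 1200, 300, 90)
estimate <- c(450, 800, 900, 350, 200)

rmse <- sqrt(mean((truth - estimate)^2))  # root mean squared error
mae <- mean(abs(truth - estimate))        # mean absolute error
rsq <- cor(truth, estimate)^2             # squared correlation

c(rmse = rmse, rsq = rsq, mae = mae)
```

RMSE is never smaller than MAE because squaring weights large errors more heavily, which is consistent with the test results above (363 versus 267).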
12 / 13

AE 05 - Regression trees

  1. Find the starter repository whose name begins with appex-05-regression-trees on GitHub and pull it into RStudio
  2. Complete the exercises
  3. Knit, commit, and push frequently. Be sure the final results are pushed to GitHub by April 3 at noon
13 / 13
