Predicting Sex of Palmer Penguins
The code below is mostly from Julia Silge’s blog (see the References section for the link), as I am following her screencast to become more familiar with the tidymodels framework.
glimpse(penguins) # 344 rows x 8 columns
Rows: 344
Columns: 8
$ species <fct> Adelie, Adelie, Adelie, Adelie, Adelie, Ad…
$ island <fct> Torgersen, Torgersen, Torgersen, Torgersen…
$ bill_length_mm <dbl> 39.1, 39.5, 40.3, NA, 36.7, 39.3, 38.9, 39…
$ bill_depth_mm <dbl> 18.7, 17.4, 18.0, NA, 19.3, 20.6, 17.8, 19…
$ flipper_length_mm <int> 181, 186, 195, NA, 193, 190, 181, 195, 193…
$ body_mass_g <int> 3750, 3800, 3250, NA, 3450, 3650, 3625, 46…
$ sex <fct> male, female, female, NA, female, male, fe…
$ year <int> 2007, 2007, 2007, 2007, 2007, 2007, 2007, …
species island bill_length_mm
0 0 2
bill_depth_mm flipper_length_mm body_mass_g
2 2 2
sex year
11 0
# Per-column summary statistics; some columns contain missing values,
# which need to be excluded before modelling.
skim(penguins_orig)
Name | penguins_orig |
Number of rows | 344 |
Number of columns | 8 |
_______________________ | |
Column type frequency: | |
factor | 3 |
numeric | 5 |
________________________ | |
Group variables | None |
Variable type: factor
skim_variable | n_missing | complete_rate | ordered | n_unique | top_counts |
---|---|---|---|---|---|
species | 0 | 1.00 | FALSE | 3 | Ade: 152, Gen: 124, Chi: 68 |
island | 0 | 1.00 | FALSE | 3 | Bis: 168, Dre: 124, Tor: 52 |
sex | 11 | 0.97 | FALSE | 2 | mal: 168, fem: 165 |
Variable type: numeric
skim_variable | n_missing | complete_rate | mean | sd | p0 | p25 | p50 | p75 | p100 | hist |
---|---|---|---|---|---|---|---|---|---|---|
bill_length_mm | 2 | 0.99 | 43.92 | 5.46 | 32.1 | 39.23 | 44.45 | 48.5 | 59.6 | ▃▇▇▆▁ |
bill_depth_mm | 2 | 0.99 | 17.15 | 1.97 | 13.1 | 15.60 | 17.30 | 18.7 | 21.5 | ▅▅▇▇▂ |
flipper_length_mm | 2 | 0.99 | 200.92 | 14.06 | 172.0 | 190.00 | 197.00 | 213.0 | 231.0 | ▂▇▃▅▂ |
body_mass_g | 2 | 0.99 | 4201.75 | 801.95 | 2700.0 | 3550.00 | 4050.00 | 4750.0 | 6300.0 | ▃▇▆▃▂ |
year | 0 | 1.00 | 2008.03 | 0.82 | 2007.0 | 2007.00 | 2008.00 | 2009.0 | 2009.0 | ▇▁▇▁▇ |
# Keep only complete rows (no NA in any column), then re-check the summary.
penguins <- drop_na(penguins_orig)
summary(penguins)
species island bill_length_mm bill_depth_mm
Adelie :146 Biscoe :163 Min. :32.10 Min. :13.10
Chinstrap: 68 Dream :123 1st Qu.:39.50 1st Qu.:15.60
Gentoo :119 Torgersen: 47 Median :44.50 Median :17.30
Mean :43.99 Mean :17.16
3rd Qu.:48.60 3rd Qu.:18.70
Max. :59.60 Max. :21.50
flipper_length_mm body_mass_g sex year
Min. :172 Min. :2700 female:165 Min. :2007
1st Qu.:190 1st Qu.:3550 male :168 1st Qu.:2007
Median :197 Median :4050 Median :2008
Mean :201 Mean :4207 Mean :2008
3rd Qu.:213 3rd Qu.:4775 3rd Qu.:2009
Max. :231 Max. :6300 Max. :2009
# Bar chart of the share of penguins in each species.
# Counts are converted to proportions and printed as percentage labels
# above each bar.
penguins %>%
  group_by(species) %>%
  summarise(
    n = n(),
    .groups = "drop"
  ) %>%
  mutate(prop = n / sum(n)) %>%
  ggplot(aes(species, prop, label = scales::percent(prop))) +
  # Fill already encodes the x-axis variable, so the legend is redundant.
  geom_col(aes(fill = species), show.legend = FALSE) +
  geom_text(vjust = -1) +
  scale_fill_jco() +
  labs(title = "Species")
#' Bar chart of category proportions for one factor column
#'
#' Counts the rows in each level of `var_x`, converts the counts to
#' proportions, and draws one bar per level with a percentage label on top.
#' The plot title is the name of the column that was passed in.
#'
#' @param var_x Unquoted name of the column to tabulate (tidy evaluation).
#' @param data Data frame holding `var_x`. Defaults to the `penguins` object
#'   in the calling environment, so existing `plot_fcts(col)` calls keep
#'   working unchanged.
#' @return A ggplot object.
plot_fcts <- function(var_x, data = penguins) {
  data %>%
    group_by({{ var_x }}) %>%
    summarise(
      n = n(),
      .groups = "drop"
    ) %>%
    mutate(prop = n / sum(n)) %>%
    ggplot(aes({{ var_x }}, prop, label = scales::percent(prop))) +
    # Fill already encodes the x-axis variable, so the legend is redundant.
    geom_col(aes(fill = {{ var_x }}), show.legend = FALSE) +
    geom_text(vjust = -1) +
    scale_fill_jco() +
    labs(title = as_label(enquo(var_x)))
}
# Proportion bar charts for each of the three factor columns.
plot_fcts(species)
plot_fcts(island)
plot_fcts(sex) # rows with NA were dropped earlier; the sexes are almost 50-50
# Bill length by sex: boxplot with the raw observations overlaid.
ggplot(penguins, aes(x = sex, y = bill_length_mm)) +
  geom_boxplot() +
  geom_point()
#' Boxplot of a continuous variable split by a factor
#'
#' Draws one box per level of `fct_x`, filled by level, and overlays the
#' individual observations as semi-transparent jittered points in the
#' matching colour. Legends are suppressed because the x-axis already
#' identifies each level.
#'
#' @param fct_x Unquoted name of the grouping (factor) column.
#' @param cont_y Unquoted name of the continuous column to plot.
#' @param data Data frame holding both columns. Defaults to the `penguins`
#'   object in the calling environment, so existing two-argument calls keep
#'   working unchanged.
#' @return A ggplot object.
plot_boxplots <- function(fct_x, cont_y, data = penguins) {
  data %>%
    ggplot(aes({{ fct_x }}, {{ cont_y }})) +
    geom_boxplot(aes(fill = {{ fct_x }}), show.legend = FALSE) +
    geom_jitter(alpha = 0.5, aes(col = {{ fct_x }}), show.legend = FALSE) +
    scale_fill_jco() +
    scale_color_jco()
}
By Species
# One boxplot per body measurement, grouped by species.
bp1 <- plot_boxplots(species, bill_length_mm)
bp2 <- plot_boxplots(species, bill_depth_mm)
bp3 <- plot_boxplots(species, flipper_length_mm)
bp4 <- plot_boxplots(species, body_mass_g)
# Combine the four plots into a 2 x 2 layout.
# NOTE(review): `+` and `/` on plots presumably come from patchwork; confirm
# it is loaded.
bp_species <- (bp1 + bp2) / (bp3 + bp4)
bp_species
By Sex
# One boxplot per body measurement, grouped by sex.
bp5 <- plot_boxplots(sex, bill_length_mm)
bp6 <- plot_boxplots(sex, bill_depth_mm)
bp7 <- plot_boxplots(sex, flipper_length_mm)
bp8 <- plot_boxplots(sex, body_mass_g)
# Combine the four plots into a 2 x 2 layout.
# NOTE(review): `+` and `/` on plots presumably come from patchwork; confirm
# it is loaded.
bp_sex <- (bp5 + bp6) / (bp7 + bp8)
bp_sex
# Correlation matrix of the numeric columns only.
# `select(where(...))` replaces the superseded scoped verb `select_if()`.
penguins %>%
  select(where(is.numeric)) %>%
  ggstatsplot::ggcorrmat()
# Pairwise plots of every column except year, coloured by sex.
ggpairs(select(penguins, -year), mapping = aes(colour = sex)) +
  scale_color_jco() +
  scale_fill_jco()
# Train/test split stratified on the outcome (sex); the seed fixes the
# random assignment of rows.
set.seed(2021110301)
penguin_split <- penguins %>%
  initial_split(strata = sex)
penguin_train <- penguin_split %>% training()
penguin_test <- penguin_split %>% testing()
# Preprocessing recipe: sex is the outcome and every other column starts as
# a predictor. The recipe template is the training set rather than the full
# data, so nothing from the held-out test rows is involved in defining it.
# island and year are moved out of the predictor role (custom role "dummy"),
# so they stay in the data but are not used by the models.
penguin_recipe <- recipe(sex ~ ., data = penguin_train) %>%
  update_role(island, year, new_role = "dummy")
summary(penguin_recipe)
# A tibble: 8 × 4
variable type role source
<chr> <chr> <chr> <chr>
1 species nominal predictor original
2 island nominal dummy original
3 bill_length_mm numeric predictor original
4 bill_depth_mm numeric predictor original
5 flipper_length_mm numeric predictor original
6 body_mass_g numeric predictor original
7 year numeric dummy original
8 sex nominal outcome original
# Bootstrap resamples of the training set, used to compare the candidate
# models; the seed fixes which rows each resample draws.
set.seed(2021110302)
penguin_boot <- bootstraps(penguin_train)
penguin_boot
# Bootstrap sampling
# A tibble: 25 × 2
splits id
<list> <chr>
1 <split [249/85]> Bootstrap01
2 <split [249/88]> Bootstrap02
3 <split [249/96]> Bootstrap03
4 <split [249/96]> Bootstrap04
5 <split [249/98]> Bootstrap05
6 <split [249/96]> Bootstrap06
7 <split [249/85]> Bootstrap07
8 <split [249/92]> Bootstrap08
9 <split [249/80]> Bootstrap09
10 <split [249/88]> Bootstrap10
# … with 15 more rows
# Logistic regression specification using the "glm" engine.
lr_model <- logistic_reg() %>%
  set_engine("glm")

# Workflow bundling the model specification with the preprocessing recipe.
lr_workflow <- workflow() %>%
  add_model(lr_model) %>%
  add_recipe(penguin_recipe)

# Fit the workflow on every bootstrap resample. Predictions are saved so
# that confusion matrices and ROC curves can be computed afterwards.
lr_fit_resamples <- lr_workflow %>%
  fit_resamples(
    resamples = penguin_boot,
    control = control_resamples(save_pred = TRUE)
  )
lr_fit_resamples
# Resampling results
# Bootstrap sampling
# A tibble: 25 × 5
splits id .metrics .notes .predictions
<list> <chr> <list> <list> <list>
1 <split [249/85]> Bootstrap01 <tibble [2 × 4]> <tibble… <tibble [85…
2 <split [249/88]> Bootstrap02 <tibble [2 × 4]> <tibble… <tibble [88…
3 <split [249/96]> Bootstrap03 <tibble [2 × 4]> <tibble… <tibble [96…
4 <split [249/96]> Bootstrap04 <tibble [2 × 4]> <tibble… <tibble [96…
5 <split [249/98]> Bootstrap05 <tibble [2 × 4]> <tibble… <tibble [98…
6 <split [249/96]> Bootstrap06 <tibble [2 × 4]> <tibble… <tibble [96…
7 <split [249/85]> Bootstrap07 <tibble [2 × 4]> <tibble… <tibble [85…
8 <split [249/92]> Bootstrap08 <tibble [2 × 4]> <tibble… <tibble [92…
9 <split [249/80]> Bootstrap09 <tibble [2 × 4]> <tibble… <tibble [80…
10 <split [249/88]> Bootstrap10 <tibble [2 × 4]> <tibble… <tibble [88…
# … with 15 more rows
# Random forest specification for classification using the "ranger" engine.
# No arguments are passed to rand_forest(), so its defaults apply.
rf_model <- rand_forest() %>%
set_mode("classification") %>%
set_engine("ranger")
rf_model
Random Forest Model Specification (classification)
Computational engine: ranger
# Workflow bundling the random forest specification with the same
# preprocessing recipe used for the logistic regression.
rf_workflow <- workflow() %>%
  add_model(rf_model) %>%
  add_recipe(penguin_recipe)

# Fit the workflow on every bootstrap resample. Predictions are saved so
# that the two models can be compared on the same resamples.
rf_fit_resamples <- rf_workflow %>%
  fit_resamples(
    resamples = penguin_boot,
    control = control_resamples(save_pred = TRUE)
  )
rf_fit_resamples
# Resampling results
# Bootstrap sampling
# A tibble: 25 × 5
splits id .metrics .notes .predictions
<list> <chr> <list> <list> <list>
1 <split [249/85]> Bootstrap01 <tibble [2 × 4]> <tibble… <tibble [85…
2 <split [249/88]> Bootstrap02 <tibble [2 × 4]> <tibble… <tibble [88…
3 <split [249/96]> Bootstrap03 <tibble [2 × 4]> <tibble… <tibble [96…
4 <split [249/96]> Bootstrap04 <tibble [2 × 4]> <tibble… <tibble [96…
5 <split [249/98]> Bootstrap05 <tibble [2 × 4]> <tibble… <tibble [98…
6 <split [249/96]> Bootstrap06 <tibble [2 × 4]> <tibble… <tibble [96…
7 <split [249/85]> Bootstrap07 <tibble [2 × 4]> <tibble… <tibble [85…
8 <split [249/92]> Bootstrap08 <tibble [2 × 4]> <tibble… <tibble [92…
9 <split [249/80]> Bootstrap09 <tibble [2 × 4]> <tibble… <tibble [80…
10 <split [249/88]> Bootstrap10 <tibble [2 × 4]> <tibble… <tibble [88…
# … with 15 more rows
# Logistic regression: metrics averaged across the bootstrap resamples.
collect_metrics(lr_fit_resamples)
# A tibble: 2 × 6
.metric .estimator mean n std_err .config
<chr> <chr> <dbl> <int> <dbl> <chr>
1 accuracy binary 0.911 25 0.00492 Preprocessor1_Model1
2 roc_auc binary 0.974 25 0.00201 Preprocessor1_Model1
# Random forest: metrics averaged across the bootstrap resamples.
collect_metrics(rf_fit_resamples)
# A tibble: 2 × 6
.metric .estimator mean n std_err .config
<chr> <chr> <dbl> <int> <dbl> <chr>
1 accuracy binary 0.897 25 0.00646 Preprocessor1_Model1
2 roc_auc binary 0.967 25 0.00238 Preprocessor1_Model1
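The logistic regression model performed slightly better than the random forest on both accuracy (0.911 vs 0.897) and ROC AUC (0.974 vs 0.967).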
# Confusion matrix for the logistic regression, averaged over the resamples
# (hence the non-integer cell counts).
lr_fit_resamples %>%
conf_mat_resampled()
# A tibble: 4 × 3
Prediction Truth Freq
<fct> <fct> <dbl>
1 female female 40.7
2 female male 3.56
3 male female 4.56
4 male male 42.6
# ROC curves for the logistic regression, one curve per bootstrap resample.
# The curve is computed from the predicted probability of "female".
lr_fit_resamples %>%
  collect_predictions() %>%
  group_by(id) %>%
  roc_curve(sex, .pred_female) %>%
  ggplot(aes(1 - specificity, sensitivity, col = id)) +
  # Dashed diagonal: reference line for a classifier with no skill.
  geom_abline(lty = 2, size = 1.5) +
  # One colour per resample id; the legend is hidden to avoid clutter.
  geom_path(show.legend = FALSE, alpha = 0.6, size = 1.2) +
  coord_equal()
# Final fit of the logistic regression workflow on the initial train/test
# split (penguin_split).
lr_penguin_final <- lr_workflow %>%
last_fit(penguin_split)
lr_penguin_final
# Resampling results
# Manual resampling
# A tibble: 1 × 6
splits id .metrics .notes .predictions .workflow
<list> <chr> <list> <list> <list> <list>
1 <split [249/84]> train/t… <tibble … <tibble… <tibble [84 … <workflo…
# Logistic regression: metrics from the final fit on the train/test split.
collect_metrics(lr_penguin_final)
# A tibble: 2 × 4
.metric .estimator .estimate .config
<chr> <chr> <dbl> <chr>
1 accuracy binary 0.893 Preprocessor1_Model1
2 roc_auc binary 0.965 Preprocessor1_Model1
# Logistic regression: confusion matrix of predicted class against true sex.
collect_predictions(lr_penguin_final) %>%
conf_mat(sex, .pred_class)
Truth
Prediction female male
female 37 4
male 5 38
# Coefficients of the final logistic regression. They are exponentiated so
# that each estimate reads as an odds ratio, and sorted largest first.
lr_penguin_final$.workflow[[1]] %>%
  tidy(exponentiate = TRUE) %>%
  arrange(desc(estimate))
# A tibble: 7 × 5
term estimate std.error statistic p.value
<chr> <dbl> <dbl> <dbl> <dbl>
1 bill_depth_mm 6.87e+ 0 0.418 4.61 0.00000412
2 bill_length_mm 1.75e+ 0 0.146 3.84 0.000122
3 flipper_length_mm 1.02e+ 0 0.0567 0.401 0.689
4 body_mass_g 1.01e+ 0 0.00149 4.59 0.00000451
5 speciesChinstrap 1.73e- 3 1.81 -3.52 0.000432
6 speciesGentoo 2.40e- 4 3.18 -2.62 0.00878
7 (Intercept) 3.04e-38 15.5 -5.57 0.0000000251
Holding the other predictors constant, an increase of 1 mm in bill depth corresponds to almost 7 times the odds of being male (odds ratio ≈ 6.87).
# Final fit of the random forest workflow on the initial train/test split
# (penguin_split), for comparison with the logistic regression.
rf_penguin_final <- rf_workflow %>%
last_fit(penguin_split)
rf_penguin_final
# Resampling results
# Manual resampling
# A tibble: 1 × 6
splits id .metrics .notes .predictions .workflow
<list> <chr> <list> <list> <list> <list>
1 <split [249/84]> train/t… <tibble … <tibble… <tibble [84 … <workflo…
# Random forest: metrics from the final fit on the train/test split.
collect_metrics(rf_penguin_final)
# A tibble: 2 × 4
.metric .estimator .estimate .config
<chr> <chr> <dbl> <chr>
1 accuracy binary 0.893 Preprocessor1_Model1
2 roc_auc binary 0.951 Preprocessor1_Model1
# Random forest: confusion matrix of predicted class against true sex.
collect_predictions(rf_penguin_final) %>%
conf_mat(sex, .pred_class)
Truth
Prediction female male
female 36 3
male 6 39
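Both models reach the same test-set accuracy (0.893), and logistic regression has a slightly higher ROC AUC (0.965 vs 0.951). Logistic regression is also a simpler and more interpretable model than a random forest.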
# Bill length against bill depth, coloured by sex, one panel per species.
ggplot(penguins, aes(x = bill_depth_mm, y = bill_length_mm, colour = sex)) +
  geom_point() +
  facet_wrap(vars(species)) +
  scale_color_jco()
- https://juliasilge.com/blog/palmer-penguins/
For attribution, please cite this work as
lruolin (2021, Nov. 4). pRactice corner: Palmer Penguins. Retrieved from https://lruolin.github.io/myBlog/posts/2021103 Palmer Penguins - Predicting Sex/
BibTeX citation
@misc{lruolin2021palmer, author = {lruolin, }, title = {pRactice corner: Palmer Penguins}, url = {https://lruolin.github.io/myBlog/posts/2021103 Palmer Penguins - Predicting Sex/}, year = {2021} }