Palmer Penguins

Predicting Sex of Palmer Penguins

lruolin
11-04-2021

Introduction

The code below is mostly adapted from Julia Silge's blog (see the Reference section for the link). I am following her screencast to become more familiar with the tidymodels framework.

Learning points

Load Packages

library(pacman)
p_load(tidymodels, palmerpenguins, janitor, skimr, ggstatsplot, ggsci, patchwork,
       GGally, ranger)

theme_set(theme_classic())

Data

glimpse(penguins) # 344 x 8 col
Rows: 344
Columns: 8
$ species           <fct> Adelie, Adelie, Adelie, Adelie, Adelie, Ad…
$ island            <fct> Torgersen, Torgersen, Torgersen, Torgersen…
$ bill_length_mm    <dbl> 39.1, 39.5, 40.3, NA, 36.7, 39.3, 38.9, 39…
$ bill_depth_mm     <dbl> 18.7, 17.4, 18.0, NA, 19.3, 20.6, 17.8, 19…
$ flipper_length_mm <int> 181, 186, 195, NA, 193, 190, 181, 195, 193…
$ body_mass_g       <int> 3750, 3800, 3250, NA, 3450, 3650, 3625, 46…
$ sex               <fct> male, female, female, NA, female, male, fe…
$ year              <int> 2007, 2007, 2007, 2007, 2007, 2007, 2007, …

Check for missing values

penguins_orig <- penguins

sapply(penguins_orig, function(x) sum(is.na(x)))
          species            island    bill_length_mm 
                0                 0                 2 
    bill_depth_mm flipper_length_mm       body_mass_g 
                2                 2                 2 
              sex              year 
               11                 0 
skim(penguins_orig) # need to exclude missing values
Table 1: Data summary
Name                     penguins_orig
Number of rows           344
Number of columns        8
_______________________
Column type frequency:
  factor                 3
  numeric                5
________________________
Group variables          None

Variable type: factor

skim_variable  n_missing  complete_rate  ordered  n_unique  top_counts
species                0           1.00  FALSE           3  Ade: 152, Gen: 124, Chi: 68
island                 0           1.00  FALSE           3  Bis: 168, Dre: 124, Tor: 52
sex                   11           0.97  FALSE           2  mal: 168, fem: 165

Variable type: numeric

skim_variable      n_missing  complete_rate     mean      sd      p0      p25      p50     p75    p100  hist
bill_length_mm             2           0.99    43.92    5.46    32.1    39.23    44.45    48.5    59.6  ▃▇▇▆▁
bill_depth_mm              2           0.99    17.15    1.97    13.1    15.60    17.30    18.7    21.5  ▅▅▇▇▂
flipper_length_mm          2           0.99   200.92   14.06   172.0   190.00   197.00   213.0   231.0  ▂▇▃▅▂
body_mass_g                2           0.99  4201.75  801.95  2700.0  3550.00  4050.00  4750.0  6300.0  ▃▇▆▃▂
year                       0           1.00  2008.03    0.82  2007.0  2007.00  2008.00  2009.0  2009.0  ▇▁▇▁▇

penguins <- penguins_orig %>% 
  drop_na()
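
To see how many rows drop_na() removes, a quick base R check can help. This is only a sketch: it assumes the palmerpenguins package is installed, and the names raw, n_before, n_after and n_dropped are my own illustrative names, not part of the original analysis.

```r
library(palmerpenguins)

# Use the packaged data directly so the check does not depend on
# any earlier reassignment of `penguins` in this post.
raw <- palmerpenguins::penguins

n_before  <- nrow(raw)                  # all rows
n_after   <- sum(complete.cases(raw))   # rows with no NA in any column
n_dropped <- n_before - n_after

c(before = n_before, after = n_after, dropped = n_dropped)
```

The number of dropped rows equals the number of missing sex values, which suggests that the 2 rows with missing measurements are also missing sex.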

Summary statistics

summary(penguins)
      species          island    bill_length_mm  bill_depth_mm  
 Adelie   :146   Biscoe   :163   Min.   :32.10   Min.   :13.10  
 Chinstrap: 68   Dream    :123   1st Qu.:39.50   1st Qu.:15.60  
 Gentoo   :119   Torgersen: 47   Median :44.50   Median :17.30  
                                 Mean   :43.99   Mean   :17.16  
                                 3rd Qu.:48.60   3rd Qu.:18.70  
                                 Max.   :59.60   Max.   :21.50  
 flipper_length_mm  body_mass_g       sex           year     
 Min.   :172       Min.   :2700   female:165   Min.   :2007  
 1st Qu.:190       1st Qu.:3550   male  :168   1st Qu.:2007  
 Median :197       Median :4050                Median :2008  
 Mean   :201       Mean   :4207                Mean   :2008  
 3rd Qu.:213       3rd Qu.:4775                3rd Qu.:2009  
 Max.   :231       Max.   :6300                Max.   :2009  

EDA

penguins %>% 
  group_by(species) %>% 
  summarise(n = n(),
            .groups = "drop") %>% 
  mutate(prop = n/sum(n)) %>% 
  ggplot(aes(species, prop, label = scales::percent(prop))) +
  geom_col(aes(fill = species), show.legend = F) +
  geom_text(vjust = -1) +
  scale_fill_jco() +
  labs(title = "Species")
plot_fcts <- function(var_x) {
  penguins %>% 
  group_by({{var_x}}) %>% 
  summarise(n = n(),
            .groups = "drop") %>% 
  mutate(prop = n/sum(n)) %>% 
  ggplot(aes({{var_x}}, prop, label = scales::percent(prop))) +
  geom_col(aes(fill = {{var_x}}), show.legend = F) +
  geom_text(vjust = -1) +
  scale_fill_jco() +
  labs(title = as_label(enquo(var_x)))

}

plot_fcts(species)
plot_fcts(island)
plot_fcts(sex) # NAs already dropped; the split is close to 50/50

penguins %>% 
  ggplot(aes(sex, bill_length_mm)) +
  geom_boxplot() +
  geom_point() 
plot_boxplots <- function(fct_x, cont_y) {
  penguins %>% 
    ggplot(aes({{fct_x}}, {{cont_y}})) +
    geom_boxplot(aes(fill = {{fct_x}}), show.legend = F) +
    geom_jitter(alpha = 0.5, aes(col = {{fct_x}}), show.legend = F) +
    scale_fill_jco() +
    scale_color_jco()
  
}

By Species

bp1 <- plot_boxplots(species, bill_length_mm)
bp2 <- plot_boxplots(species, bill_depth_mm)
bp3 <- plot_boxplots(species, flipper_length_mm)
bp4 <- plot_boxplots(species, body_mass_g)

bp_species <- (bp1 + bp2) / (bp3 + bp4)
bp_species

By Sex

bp5 <- plot_boxplots(sex, bill_length_mm)
bp6 <- plot_boxplots(sex, bill_depth_mm)
bp7 <- plot_boxplots(sex, flipper_length_mm)
bp8 <- plot_boxplots(sex, body_mass_g)

bp_sex <- (bp5 + bp6) / (bp7 + bp8)
bp_sex

Correlation

penguins %>% 
  select(where(is.numeric)) %>% # select_if() is superseded in current dplyr
  ggstatsplot::ggcorrmat()

penguins %>% 
  select(-year) %>% 
  ggpairs(aes(col = sex)) +
  scale_color_jco() +
  scale_fill_jco()

Split data

set.seed(2021110301)

penguin_split <- initial_split(penguins, strata = sex)
penguin_train <- training(penguin_split)
penguin_test <- testing(penguin_split)
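
The strata = sex argument asks initial_split() to keep the male/female balance similar in both parts, and the default proportion of 3/4 is consistent with the 249/84 split seen later in this post. The sketch below checks this on a fresh split. It assumes rsample and palmerpenguins are installed; the seed and the demo_* names are arbitrary choices of mine, so the exact rows will differ from the split above.

```r
library(rsample)
library(palmerpenguins)

# Complete cases only, mirroring the drop_na() step
peng <- na.omit(palmerpenguins::penguins)

set.seed(123)  # arbitrary seed, not the one used in the post
demo_split <- initial_split(peng, prop = 0.75, strata = sex)
demo_train <- training(demo_split)
demo_test  <- testing(demo_split)

# Class balance should be similar in both parts
round(prop.table(table(demo_train$sex)), 3)
round(prop.table(table(demo_test$sex)), 3)

# Every row ends up in exactly one of the two parts
c(train = nrow(demo_train), test = nrow(demo_test))
```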

Create recipe

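# Define the recipe on the training data only, so nothing about the test set is used
penguin_recipe <- recipe(sex ~ ., data = penguin_train) %>% 
  # island and year get a non-predictor role, so they are not used as predictors
  update_role(island, year, new_role = "dummy")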

summary(penguin_recipe)
# A tibble: 8 × 4
  variable          type    role      source  
  <chr>             <chr>   <chr>     <chr>   
1 species           nominal predictor original
2 island            nominal dummy     original
3 bill_length_mm    numeric predictor original
4 bill_depth_mm     numeric predictor original
5 flipper_length_mm numeric predictor original
6 body_mass_g       numeric predictor original
7 year              numeric dummy     original
8 sex               nominal outcome   original
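
The role name "dummy" is just a label I picked. As far as I understand, any role other than predictor or outcome keeps the column in the data but leaves it out of the model fit. A small sketch to list what is left as predictors, assuming recipes and palmerpenguins are installed (demo_recipe, role_info and predictors are illustrative names):

```r
library(recipes)
library(palmerpenguins)

peng <- na.omit(palmerpenguins::penguins)

# Same idea as the recipe above: give island and year a non-predictor role
demo_recipe <- recipe(sex ~ ., data = peng)
demo_recipe <- update_role(demo_recipe, island, year, new_role = "dummy")

# summary() returns one row per variable with its current role
role_info  <- summary(demo_recipe)
predictors <- role_info$variable[role_info$role == "predictor"]
predictors
```

An alternative that I have not used here would be to drop the columns outright with step_rm(island, year).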

Create bootstrap

set.seed(2021110302)
penguin_boot <- bootstraps(penguin_train)
penguin_boot
# Bootstrap sampling 
# A tibble: 25 × 2
   splits           id         
   <list>           <chr>      
 1 <split [249/85]> Bootstrap01
 2 <split [249/88]> Bootstrap02
 3 <split [249/96]> Bootstrap03
 4 <split [249/96]> Bootstrap04
 5 <split [249/98]> Bootstrap05
 6 <split [249/96]> Bootstrap06
 7 <split [249/85]> Bootstrap07
 8 <split [249/92]> Bootstrap08
 9 <split [249/80]> Bootstrap09
10 <split [249/88]> Bootstrap10
# … with 15 more rows
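
In each split [249/85], the first number is the analysis set (249 rows drawn with replacement from the training data) and the second is the assessment set (the rows that were never drawn). The base R sketch below mimics one such resample to show why the second number hovers around a third of 249. It is not how rsample is implemented internally, and the seed is arbitrary.

```r
set.seed(1)   # arbitrary seed, only for reproducibility
n <- 249      # size of penguin_train in this post

# Analysis set: n draws with replacement
in_bag  <- sample(n, size = n, replace = TRUE)
# Assessment set: rows that were never drawn ("out-of-bag")
out_bag <- setdiff(seq_len(n), in_bag)

length(out_bag)          # comparable to the 80 to 98 seen above
length(out_bag) / n      # observed out-of-bag fraction
(1 - 1 / n)^n            # expected fraction, close to exp(-1)
```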

Log Reg Model

lr_model <- logistic_reg() %>% 
  set_engine("glm")
lr_workflow <- workflow() %>% 
  add_model(lr_model) %>% 
  add_recipe(penguin_recipe)
lr_fit_resamples <- lr_workflow %>% 
  fit_resamples(
    resamples = penguin_boot,
    control = control_resamples(save_pred = T)
  )

lr_fit_resamples
# Resampling results
# Bootstrap sampling 
# A tibble: 25 × 5
   splits           id          .metrics         .notes   .predictions
   <list>           <chr>       <list>           <list>   <list>      
 1 <split [249/85]> Bootstrap01 <tibble [2 × 4]> <tibble… <tibble [85…
 2 <split [249/88]> Bootstrap02 <tibble [2 × 4]> <tibble… <tibble [88…
 3 <split [249/96]> Bootstrap03 <tibble [2 × 4]> <tibble… <tibble [96…
 4 <split [249/96]> Bootstrap04 <tibble [2 × 4]> <tibble… <tibble [96…
 5 <split [249/98]> Bootstrap05 <tibble [2 × 4]> <tibble… <tibble [98…
 6 <split [249/96]> Bootstrap06 <tibble [2 × 4]> <tibble… <tibble [96…
 7 <split [249/85]> Bootstrap07 <tibble [2 × 4]> <tibble… <tibble [85…
 8 <split [249/92]> Bootstrap08 <tibble [2 × 4]> <tibble… <tibble [92…
 9 <split [249/80]> Bootstrap09 <tibble [2 × 4]> <tibble… <tibble [80…
10 <split [249/88]> Bootstrap10 <tibble [2 × 4]> <tibble… <tibble [88…
# … with 15 more rows

Random Forest Model

rf_model <- rand_forest() %>% 
  set_mode("classification") %>% 
  set_engine("ranger")

rf_model
Random Forest Model Specification (classification)

Computational engine: ranger 
rf_workflow <- workflow() %>% 
  add_model(rf_model) %>% 
  add_recipe(penguin_recipe)
rf_fit_resamples <- rf_workflow %>% 
  fit_resamples(
    resamples = penguin_boot,
    control = control_resamples(save_pred = T)
  )

rf_fit_resamples
# Resampling results
# Bootstrap sampling 
# A tibble: 25 × 5
   splits           id          .metrics         .notes   .predictions
   <list>           <chr>       <list>           <list>   <list>      
 1 <split [249/85]> Bootstrap01 <tibble [2 × 4]> <tibble… <tibble [85…
 2 <split [249/88]> Bootstrap02 <tibble [2 × 4]> <tibble… <tibble [88…
 3 <split [249/96]> Bootstrap03 <tibble [2 × 4]> <tibble… <tibble [96…
 4 <split [249/96]> Bootstrap04 <tibble [2 × 4]> <tibble… <tibble [96…
 5 <split [249/98]> Bootstrap05 <tibble [2 × 4]> <tibble… <tibble [98…
 6 <split [249/96]> Bootstrap06 <tibble [2 × 4]> <tibble… <tibble [96…
 7 <split [249/85]> Bootstrap07 <tibble [2 × 4]> <tibble… <tibble [85…
 8 <split [249/92]> Bootstrap08 <tibble [2 × 4]> <tibble… <tibble [92…
 9 <split [249/80]> Bootstrap09 <tibble [2 × 4]> <tibble… <tibble [80…
10 <split [249/88]> Bootstrap10 <tibble [2 × 4]> <tibble… <tibble [88…
# … with 15 more rows

Evaluate models

collect_metrics(lr_fit_resamples)
# A tibble: 2 × 6
  .metric  .estimator  mean     n std_err .config             
  <chr>    <chr>      <dbl> <int>   <dbl> <chr>               
1 accuracy binary     0.911    25 0.00492 Preprocessor1_Model1
2 roc_auc  binary     0.974    25 0.00201 Preprocessor1_Model1
collect_metrics(rf_fit_resamples)
# A tibble: 2 × 6
  .metric  .estimator  mean     n std_err .config             
  <chr>    <chr>      <dbl> <int>   <dbl> <chr>               
1 accuracy binary     0.897    25 0.00646 Preprocessor1_Model1
2 roc_auc  binary     0.967    25 0.00238 Preprocessor1_Model1

Logistic regression performed slightly better on the bootstrap resamples (accuracy 0.911 vs 0.897, ROC AUC 0.974 vs 0.967).
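
To get a feel for how large the gap is relative to the resampling noise, the reported means and standard errors can be turned into rough intervals. This is a sketch under a normal approximation (mean ± 1.96 × std_err). Bootstrap resamples overlap with each other, so I would treat these intervals as indicative only.

```r
# Accuracy means and standard errors copied from the collect_metrics() output above
acc <- data.frame(
  model   = c("logistic regression", "random forest"),
  mean    = c(0.911, 0.897),
  std_err = c(0.00492, 0.00646)
)

# Rough 95% intervals under a normal approximation
acc$lower <- acc$mean - 1.96 * acc$std_err
acc$upper <- acc$mean + 1.96 * acc$std_err
acc

# TRUE if the two intervals overlap
acc$lower[1] < acc$upper[2]
```

The intervals overlap, so the advantage of logistic regression on accuracy is small.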

Confusion matrix on resampled data

lr_fit_resamples %>% 
  conf_mat_resampled()
# A tibble: 4 × 3
  Prediction Truth   Freq
  <fct>      <fct>  <dbl>
1 female     female 40.7 
2 female     male    3.56
3 male       female  4.56
4 male       male   42.6 
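
The cell counts are averages over the 25 resamples, which is why they are not whole numbers. As a sanity check, the usual metrics can be recomputed by hand from these four numbers. The sketch below treats female as the event, assuming it is the first factor level, and the variable names are my own.

```r
# Averaged cell counts copied from conf_mat_resampled() above
# (names read as prediction_truth)
female_female <- 40.7
female_male   <- 3.56
male_female   <- 4.56
male_male     <- 42.6

total <- female_female + female_male + male_female + male_male

accuracy    <- (female_female + male_male) / total
sensitivity <- female_female / (female_female + male_female)  # true females found
specificity <- male_male / (male_male + female_male)          # true males found

round(c(accuracy = accuracy,
        sensitivity = sensitivity,
        specificity = specificity), 3)
```

The accuracy recomputed this way agrees with the 0.911 reported by collect_metrics().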

ROC curve for resampled data

lr_fit_resamples %>% 
  collect_predictions() %>% 
  group_by(id) %>% 
  roc_curve(sex, .pred_female) %>% 
  ggplot(aes(1-specificity, sensitivity, col = id)) +
  geom_abline(lty = 2, size = 1.5) +
  geom_path(show.legend = F, alpha = 0.6, size = 1.2) +
  coord_equal()
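
roc_curve() is given .pred_female because yardstick treats the first factor level (here female) as the event by default. To make the area under the curve less abstract, here is a toy base R sketch with made-up numbers: the AUC is the share of (female, male) pairs in which the female gets the higher predicted probability of being female. The data are invented for illustration and are not from the penguin model.

```r
# Made-up truth and predicted probabilities of "female" for six birds
truth    <- factor(c("female", "female", "male", "male", "female", "male"),
                   levels = c("female", "male"))
p_female <- c(0.90, 0.80, 0.30, 0.75, 0.70, 0.10)

pos <- p_female[truth == "female"]  # scores of the event class
neg <- p_female[truth == "male"]    # scores of the other class

# Compare every female score with every male score; ties count as half
auc <- mean(outer(pos, neg, ">") + 0.5 * outer(pos, neg, "=="))
auc
```

Here 8 of the 9 pairs are ordered correctly, so the AUC is 8/9.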

Evaluate on testing dataset - Log Reg

lr_penguin_final <- lr_workflow %>% 
  last_fit(penguin_split)

lr_penguin_final
# Resampling results
# Manual resampling 
# A tibble: 1 × 6
  splits           id       .metrics  .notes   .predictions  .workflow
  <list>           <chr>    <list>    <list>   <list>        <list>   
1 <split [249/84]> train/t… <tibble … <tibble… <tibble [84 … <workflo…
collect_metrics(lr_penguin_final)
# A tibble: 2 × 4
  .metric  .estimator .estimate .config             
  <chr>    <chr>          <dbl> <chr>               
1 accuracy binary         0.893 Preprocessor1_Model1
2 roc_auc  binary         0.965 Preprocessor1_Model1
collect_predictions(lr_penguin_final) %>% 
  conf_mat(sex, .pred_class)
          Truth
Prediction female male
    female     37    4
    male        5   38
lr_penguin_final$.workflow[[1]] %>% 
  tidy(exponentiate = T) %>% 
  arrange(desc(estimate))
# A tibble: 7 × 5
  term              estimate std.error statistic      p.value
  <chr>                <dbl>     <dbl>     <dbl>        <dbl>
1 bill_depth_mm     6.87e+ 0   0.418       4.61  0.00000412  
2 bill_length_mm    1.75e+ 0   0.146       3.84  0.000122    
3 flipper_length_mm 1.02e+ 0   0.0567      0.401 0.689       
4 body_mass_g       1.01e+ 0   0.00149     4.59  0.00000451  
5 speciesChinstrap  1.73e- 3   1.81       -3.52  0.000432    
6 speciesGentoo     2.40e- 4   3.18       -2.62  0.00878     
7 (Intercept)       3.04e-38  15.5        -5.57  0.0000000251

Holding the other predictors constant, a 1 mm increase in bill depth corresponds to almost 7 times the odds of being male (odds ratio 6.87).
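
Because tidy(exponentiate = T) reports odds ratios, the estimates multiply the odds rather than add to a probability. A short worked example using the bill depth estimate of 6.87 from the table above. The starting odds of 1 (a 50/50 bird) are a hypothetical choice of mine for illustration.

```r
or_bill_depth <- 6.87   # exponentiated coefficient from the table above

# Back on the log-odds scale, this is the raw glm coefficient
log_odds_per_mm <- log(or_bill_depth)

# Odds ratios compound multiplicatively: a 2 mm increase
or_2mm <- or_bill_depth^2

# Hypothetical bird starting at even odds of being male (probability 0.5)
odds_start <- 1
odds_new   <- odds_start * or_bill_depth
p_new      <- odds_new / (1 + odds_new)

round(c(log_odds_per_mm = log_odds_per_mm,
        or_2mm = or_2mm,
        p_new = p_new), 3)
```

So for a bird at even odds, 1 mm of extra bill depth moves the predicted probability of being male from 0.5 to about 0.87, all else being equal.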

Evaluate on testing dataset - Random Forest

rf_penguin_final <- rf_workflow %>% 
  last_fit(penguin_split)

rf_penguin_final
# Resampling results
# Manual resampling 
# A tibble: 1 × 6
  splits           id       .metrics  .notes   .predictions  .workflow
  <list>           <chr>    <list>    <list>   <list>        <list>   
1 <split [249/84]> train/t… <tibble … <tibble… <tibble [84 … <workflo…
collect_metrics(rf_penguin_final)
# A tibble: 2 × 4
  .metric  .estimator .estimate .config             
  <chr>    <chr>          <dbl> <chr>               
1 accuracy binary         0.893 Preprocessor1_Model1
2 roc_auc  binary         0.951 Preprocessor1_Model1
collect_predictions(rf_penguin_final) %>% 
  conf_mat(sex, .pred_class)
          Truth
Prediction female male
    female     36    3
    male        6   39

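On the test set both models reach the same accuracy (0.893), and logistic regression has a slightly higher ROC AUC (0.965 vs 0.951). Logistic regression is also the simpler and more interpretable of the two models, which makes it the more attractive choice here.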
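
The two test-set confusion matrices can be compared directly. The sketch below re-enters the counts printed above by hand. The accuracy() helper is my own small function, not the yardstick one.

```r
# Counts copied from the two conf_mat() outputs above
# (rows = prediction, columns = truth; matrix() fills column by column)
lr_cm <- matrix(c(37, 5, 4, 38), nrow = 2,
                dimnames = list(Prediction = c("female", "male"),
                                Truth      = c("female", "male")))
rf_cm <- matrix(c(36, 6, 3, 39), nrow = 2,
                dimnames = list(Prediction = c("female", "male"),
                                Truth      = c("female", "male")))

# Share of correct predictions: diagonal over total
accuracy <- function(cm) sum(diag(cm)) / sum(cm)

round(c(log_reg = accuracy(lr_cm), random_forest = accuracy(rf_cm)), 3)
```

Both models get 75 of the 84 test penguins right. They only differ in which kind of mistake they make.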

Visualization to confirm

penguins %>% 
  ggplot(aes(bill_depth_mm, bill_length_mm, col = sex)) +
  geom_point() +
  facet_wrap(~species) +
  scale_color_jco()

Reference:

- https://juliasilge.com/blog/palmer-penguins/

Citation

For attribution, please cite this work as

lruolin (2021, Nov. 4). pRactice corner: Palmer Penguins. Retrieved from https://lruolin.github.io/myBlog/posts/2021103 Palmer Penguins - Predicting Sex/

BibTeX citation

@misc{lruolin2021palmer,
  author = {lruolin},
  title = {pRactice corner: Palmer Penguins},
  url = {https://lruolin.github.io/myBlog/posts/2021103 Palmer Penguins - Predicting Sex/},
  year = {2021}
}