Practicing predicting with logistic regression for bird bath observation
This is a post reproducing Julia Silge’s tutorial on logistic regression to predict the probability of seeing birds at bird baths, using the tidymodels framework.
The data is from Tidy Tuesday:
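library(tidyverse); library(tidymodels) # core packages
library(skimr); library(janitor); library(ggsci); library(patchwork) # skim(), clean_names(), jco palettes, plot layout
# to download data
bird_baths <- readr::read_csv('https://raw.githubusercontent.com/rfordatascience/tidytuesday/master/data/2021/2021-08-31/bird_baths.csv')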
bird_baths %>%
head()
# A tibble: 6 × 5
survey_year urban_rural bioregions bird_type bird_count
<dbl> <chr> <chr> <chr> <dbl>
1 2014 Urban South Eastern Q… Bassian Thrush 0
2 2014 Urban South Eastern Q… Chestnut-breast… 0
3 2014 Urban South Eastern Q… Wild Duck 0
4 2014 Urban South Eastern Q… Willie Wagtail 0
5 2014 Urban South Eastern Q… Regent Bowerbird 0
6 2014 Urban South Eastern Q… Rufous Fantail 0
From the documentation:
bird_baths %>%
skim()
Data summary | |
---|---|
Name | Piped data |
Number of rows | 161057 |
Number of columns | 5 |
Column type frequency: | |
character | 3 |
numeric | 2 |
Group variables | None |
Variable type: character
skim_variable | n_missing | complete_rate | min | max | empty | n_unique | whitespace |
---|---|---|---|---|---|---|---|
urban_rural | 169 | 1 | 5 | 5 | 0 | 2 | 0 |
bioregions | 169 | 1 | 12 | 24 | 0 | 10 | 0 |
bird_type | 0 | 1 | 4 | 28 | 0 | 169 | 0 |
Variable type: numeric
skim_variable | n_missing | complete_rate | mean | sd | p0 | p25 | p50 | p75 | p100 | hist |
---|---|---|---|---|---|---|---|---|---|---|
survey_year | 169 | 1 | 2014.45 | 0.50 | 2014 | 2014 | 2014 | 2015 | 2015 | ▇▁▁▁▆ |
bird_count | 0 | 1 | 0.07 | 2.01 | 0 | 0 | 0 | 0 | 292 | ▇▁▁▁▁ |
This is a very long dataset: 161,057 rows by 5 variables. survey_year, urban_rural and bioregions each have 169 missing values. Plan: change survey_year and all character variables to factors, and turn bird_count into a seen/not-seen factor (0 = not seen, more than 0 = seen).
glimpse(bird_baths)
Rows: 161,057
Columns: 5
$ survey_year <dbl> 2014, 2014, 2014, 2014, 2014, 2014, 2014, 2014, …
$ urban_rural <chr> "Urban", "Urban", "Urban", "Urban", "Urban", "Ur…
$ bioregions <chr> "South Eastern Queensland", "South Eastern Queen…
$ bird_type <chr> "Bassian Thrush", "Chestnut-breasted Mannikin", …
$ bird_count <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
Rows with missing values (169 rows, with survey_year, urban_rural and bioregions all NA):
# A tibble: 169 × 5
survey_year urban_rural bioregions bird_type bird_count
<dbl> <chr> <chr> <chr> <dbl>
1 NA <NA> <NA> Bassian Thrush 4
2 NA <NA> <NA> Chestnut-breasted Ma… 9
3 NA <NA> <NA> Wild Duck 9
4 NA <NA> <NA> Willie Wagtail 114
5 NA <NA> <NA> Regent Bowerbird 11
6 NA <NA> <NA> Rufous Fantail 19
7 NA <NA> <NA> Spiny-cheeked Honeye… 4
8 NA <NA> <NA> Flame Robin 1
9 NA <NA> <NA> European Goldfinch 6
10 NA <NA> <NA> Noisy Friarbird 34
# … with 159 more rows
bird_baths %>%
count(bird_count)
# A tibble: 65 × 2
bird_count n
<dbl> <int>
1 0 155344
2 1 5570
3 2 12
4 3 12
5 4 9
6 5 7
7 6 4
8 7 6
9 8 6
10 9 8
# … with 55 more rows
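Because bird_count is a count rather than a 0/1 flag (the table above has 65 distinct values), calling as.factor() on it directly would give one level per distinct count. A minimal base-R sketch on a made-up vector (`toy_counts` is hypothetical) of collapsing counts into the two classes used for modelling:

```r
# Hypothetical counts shaped like bird_count: mostly 0, a few larger values
toy_counts <- c(0, 0, 0, 1, 0, 2, 0, 292, 1, 0)

# Converting raw counts gives one factor level per distinct count (0, 1, 2, 292)
raw_levels <- nlevels(factor(toy_counts))

# Collapsing to "seen at least once" vs "not seen" gives a binary outcome
toy_seen <- factor(ifelse(toy_counts > 0, "bird seen", "not seen"))

print(raw_levels)
print(table(toy_seen))
```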
bird_bath_working <- bird_baths %>%
mutate(across(.cols = survey_year:bird_type , as.factor)) %>%
na.exclude() # remove NA
glimpse(bird_bath_working)
Rows: 160,888
Columns: 5
$ survey_year <fct> 2014, 2014, 2014, 2014, 2014, 2014, 2014, 2014, …
$ urban_rural <fct> Urban, Urban, Urban, Urban, Urban, Urban, Urban,…
$ bioregions <fct> South Eastern Queensland, South Eastern Queensla…
$ bird_type <fct> Bassian Thrush, Chestnut-breasted Mannikin, Wild…
$ bird_count <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
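skim() reports 169 missing values in each of survey_year, urban_rural and bioregions, and na.exclude() removed exactly 161,057 - 160,888 = 169 rows, so the missing values sit in the same rows. A small base-R sketch of that check on a made-up data frame (`toy` is hypothetical):

```r
# Hypothetical data frame where the NAs share one row, as in bird_baths
toy <- data.frame(
  survey_year = c(2014, 2015, NA, 2014),
  urban_rural = c("Urban", "Rural", NA, "Urban"),
  bioregions  = c("Sydney Basin", "Sydney Basin", NA, "South Eastern Queensland"),
  bird_count  = c(0, 1, 4, 0)
)

na_per_column <- colSums(is.na(toy))  # 1 NA in each of the first three columns
toy_complete  <- na.exclude(toy)      # drops every row containing any NA

rows_dropped <- nrow(toy) - nrow(toy_complete)
print(na_per_column)
print(rows_dropped)

# Same check with the row counts reported by glimpse() above
print(161057 - 160888)
```

If the NAs were spread over different rows, rows_dropped would be larger than any single column's NA count.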
Reusing my ggplot code from previous post:
# Create function (despite its name, it draws a bar chart of counts for one variable):
plot_fct_boxplot <- function(var_x) {
bird_bath_working %>%
mutate(var_x_fct = factor({{var_x}})) %>% # so that other classes can be used for plot as well
ggplot(aes(var_x_fct)) +
geom_bar(aes(fill = var_x_fct), show.legend = FALSE) +
geom_text(stat = "count", aes(label = after_stat(count)), hjust = -0.1) +
scale_fill_jco() + # from ggsci package
scale_y_continuous(expand = expansion(mult = c(0, 0.2))) +
theme_classic() +
labs(title = str_to_title(as_label(enquo(var_x))),
x = as_label(enquo(var_x))) +
theme(axis.title = element_text(face = "bold"),
axis.text.x = element_text(face = "bold")) +
coord_flip()
}
# Set names to loop through:
var_fct <- bird_bath_working %>%
select(survey_year, urban_rural, bioregions, bird_count) %>%
names()
plots <- var_fct %>%
syms() %>%
map(~ plot_fct_boxplot(!!.x)) # inject each column symbol into the function
(plots[[1]] + plots[[2]] + plots[[4]]) / plots[[3]]
bird_count is very imbalanced: the vast majority of rows are 0. Most of the observations come from urban areas (about 70%), and most of the data is from the Sydney Basin.
bird_bath_working %>%
distinct(bird_type) %>%
count() # there are 169 types of birds seen
# A tibble: 1 × 1
n
<int>
1 169
top_15_birds <- bird_bath_working %>%
group_by(bird_type) %>%
summarise(times_seen = sum(bird_count)) %>%
arrange(desc(times_seen)) %>%
slice_head(n = 15)
names_top_birds <- top_15_birds %>%
pull(bird_type)
bird_bath_working %>%
filter(bird_type %in% names_top_birds) %>%
group_by(bird_type, bioregions) %>%
summarise(times_seen = sum(bird_count), .groups = "drop") %>%
arrange(desc(times_seen)) %>%
ggplot(aes(bioregions, fct_rev(bird_type))) +
geom_tile(aes(fill = times_seen)) +
scale_fill_continuous(low = "snow1", high = "darkorange4") +
labs(title = "Heatmap for the top 15 types of birds (by total number of times seen)",
x = "Bio Regions",
y = "Bird type",
caption = "Source: Cleary et al, 2016") +
theme_classic() +
theme(axis.text = element_text(face = "bold"))
bottom_15_birds <- bird_bath_working %>%
group_by(bird_type) %>%
summarise(times_seen = sum(bird_count)) %>%
arrange(desc(times_seen)) %>%
slice_tail(n = 15)
names_rare_birds <- bottom_15_birds %>%
pull(bird_type)
names_rare_birds
[1] Mangrove Gerygone Mangrove Honeyeater
[3] Pale-yellow Robin Purple Swamphen
[5] Purple-crowned Lorikeet Satin Flycatcher
[7] Song Thrush Southern Whiteface
[9] Striated Pardalote Superb Lyrebird
[11] Tawny-crowned Honeyeater Tete
[13] Varied Triller Welcome Swallow
[15] Red-tailed Black Cockatoo
169 Levels: Apostlebird ... Yellow-tufted Honeyeater
bird_bath_working %>%
filter(bird_type %in% names_rare_birds) %>%
group_by(bird_type, bioregions) %>%
summarise(times_seen = sum(bird_count), .groups = "drop") %>%
arrange(desc(times_seen)) %>%
ggplot(aes(bioregions, fct_rev(bird_type))) +
geom_tile(aes(fill = times_seen)) +
scale_fill_continuous(low = "snow1", high = "darkorange4") +
labs(title = "Heatmap for the rarest 15 types of birds (by total number of times seen)",
x = "Bio Regions",
y = "Bird type",
caption = "Source: Cleary et al, 2016") +
theme_classic() +
theme(axis.text = element_text(face = "bold"))
# average number of times each of the top 15 birds was seen, by urban/rural area
top_15_urban_rural <- bird_bath_working %>%
filter(bird_type %in% names_top_birds) %>%
group_by(urban_rural, bird_type) %>%
summarise(av_times_seen = mean(bird_count), .groups = "drop") %>% # mean is used, not sum
arrange(desc(av_times_seen))
head(top_15_urban_rural)
# A tibble: 6 × 3
urban_rural bird_type av_times_seen
<fct> <fct> <dbl>
1 Rural Superb Fairy-wren 0.364
2 Urban Noisy Miner 0.354
3 Urban Rainbow Lorikeet 0.296
4 Rural Eastern Spinebill 0.296
5 Urban Australian Magpie 0.283
6 Rural Grey Fantail 0.282
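Using mean() rather than sum() matters here because roughly 70% of the observations are urban: a total would favour urban areas simply because they were surveyed more often. A toy illustration with made-up numbers (`toy_obs` is hypothetical):

```r
# Hypothetical survey rows for one bird: 7 urban and 3 rural
toy_obs <- data.frame(
  urban_rural = c(rep("Urban", 7), rep("Rural", 3)),
  bird_count  = c(1, 0, 1, 0, 1, 0, 0,   1, 1, 0)
)

# Totals favour the area with more survey rows...
totals <- tapply(toy_obs$bird_count, toy_obs$urban_rural, sum)
# ...while means put both areas on a per-observation scale
averages <- tapply(toy_obs$bird_count, toy_obs$urban_rural, mean)

print(totals)
print(averages)
```

Here the urban total is higher (3 vs 2) even though the bird turns up more often per rural observation (about 0.67 vs 0.43).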
# pivot wider so the Rural and Urban values can be used as segment ends in geom_segment later
top_15_urban_rural_wide <- top_15_urban_rural %>%
pivot_wider(names_from = urban_rural,
values_from = av_times_seen) %>%
arrange(desc(Urban)) %>%
rowid_to_column("order")
head(top_15_urban_rural_wide)
# A tibble: 6 × 4
order bird_type Rural Urban
<int> <fct> <dbl> <dbl>
1 1 Noisy Miner 0.201 0.354
2 2 Rainbow Lorikeet 0.136 0.296
3 3 Australian Magpie 0.259 0.283
4 4 Spotted Dove 0.0374 0.213
5 5 Red Wattlebird 0.119 0.207
6 6 Magpie-lark 0.112 0.193
top_15_urban_rural %>%
# join to get the order column and the Rural/Urban values for the segments
left_join(top_15_urban_rural_wide, by = "bird_type") %>%
ggplot(aes(av_times_seen, fct_reorder(bird_type, -order))) +
geom_point(aes(col = urban_rural), size = 3) +
geom_segment(
aes(x = Rural, xend = Urban,
y = bird_type, yend = bird_type),
alpha = 0.7, col = "gray70", size = 1
) +
scale_x_continuous(limits = c(0, 0.5)) + # average counts, not percentages
scale_color_jco() +
labs(title = "Average number of times seen for the 15 most common types of birds",
subtitle = "Birds are arranged in descending order of average count in urban areas",
x = "Average number of times seen",
y = "Bird type",
caption = "Source: Cleary et al, 2016",
col = NULL) +
theme_classic() +
theme(legend.position = "top")
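Learning point: for a classification model, tidymodels needs the outcome to be a factor, not a character vector. Remember to convert the outcome (and, for consistency, the other character variables) to factors before modelling.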
bird_bath_working <-
bird_bath_working %>%
mutate(bird_seen = if_else(bird_count > 0, "bird seen", "not seen")) %>% # outcome must be a factor!
mutate(across(where(is.character), as.factor))
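One detail worth checking after as.factor(): levels are sorted alphabetically, so "bird seen" becomes the first level. yardstick metrics such as roc_auc() treat the first level as the event by default (event_level = "first"), which is why the ROC code later in the post uses the `.pred_bird seen` column (pred_bird_seen after clean_names()). A quick base-R check on made-up counts:

```r
# The same if_else-then-factor conversion, on a few hypothetical counts
toy_counts <- c(0, 3, 0, 1)
toy_seen <- factor(ifelse(toy_counts > 0, "bird seen", "not seen"))

# Alphabetical ordering puts "bird seen" first
print(levels(toy_seen))

# If the other class should come first, relevel explicitly rather than
# relying on alphabetical order
toy_seen_relevelled <- relevel(toy_seen, ref = "not seen")
print(levels(toy_seen_relevelled))
```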
set.seed(2021101001)
bird_split <- initial_split(bird_bath_working,
strata = bird_count) # since there is an imbalance
# create a df for train and test data
bird_train_data <- training(bird_split) # 120,666
bird_test_data <- testing(bird_split) # 40,222
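initial_split() keeps 3/4 of the rows for training by default (160,888 × 0.75 = 120,666), and `strata =` asks rsample to sample within groups so the rare "seen" rows are spread proportionally between the two sets (numeric strata such as bird_count are first binned into quartiles). A base-R sketch of the idea on a hypothetical imbalanced outcome; it illustrates stratified sampling and is not rsample's implementation:

```r
set.seed(1)

# Hypothetical imbalanced outcome: 4 "bird seen" rows out of 100
toy_y <- factor(c(rep("bird seen", 4), rep("not seen", 96)))
prop_train <- 3 / 4

# Sample 3/4 of the row indices separately within each class (stratum)
train_idx <- unlist(lapply(split(seq_along(toy_y), toy_y), function(idx) {
  idx[sample.int(length(idx), floor(length(idx) * prop_train))]
}))

toy_train <- toy_y[train_idx]
toy_test  <- toy_y[-train_idx]

print(table(toy_train))  # 3 seen, 72 not seen
print(table(toy_test))   # 1 seen, 24 not seen
```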
# define model:
lr_model <- logistic_reg()
lr_model
Logistic Regression Model Specification (classification)
Computational engine: glm
# create recipe: outcome and two predictors
bird_recipe <- recipe(bird_seen ~ bird_type + urban_rural, data = bird_train_data) %>%
step_dummy(all_nominal_predictors()) # to create dummy variables from factors
summary(bird_recipe)
# A tibble: 3 × 4
variable type role source
<chr> <chr> <chr> <chr>
1 bird_type nominal predictor original
2 urban_rural nominal predictor original
3 bird_seen nominal outcome original
bird_recipe %>% prep()
Recipe
Inputs:
role #variables
outcome 1
predictor 2
Training data contained 120666 data points and no missing data.
Operations:
Dummy variables from bird_type, urban_rural [trained]
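By default step_dummy() uses reference-cell coding, like base R's model.matrix(): a factor with k levels becomes k - 1 indicator columns, with the first level absorbed into the intercept. With 169 bird types and 2 urban/rural levels that gives 168 + 1 = 169 predictor columns. A small base-R illustration with hypothetical data (column names differ from the ones recipes generates):

```r
# Hypothetical predictors: 3 bird types and 2 area types
toy <- data.frame(
  bird_type   = factor(c("Noisy Miner", "Spotted Dove", "Grey Fantail", "Noisy Miner")),
  urban_rural = factor(c("Urban", "Rural", "Rural", "Urban"))
)

# Treatment (reference-cell) coding: k levels -> k - 1 indicator columns
design <- model.matrix(~ bird_type + urban_rural, data = toy)
print(colnames(design))  # intercept + 2 bird_type columns + 1 urban_rural column

# The same count for the real predictors: 169 bird types, 2 area types
n_dummies <- (169 - 1) + (2 - 1)
print(n_dummies)
```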
# bundle workflow
bird_workflow <-
workflow() %>%
add_model(lr_model) %>%
add_recipe(bird_recipe)
bird_workflow
══ Workflow ══════════════════════════════════════════════════════════
Preprocessor: Recipe
Model: logistic_reg()
── Preprocessor ──────────────────────────────────────────────────────
1 Recipe Step
• step_dummy()
── Model ─────────────────────────────────────────────────────────────
Logistic Regression Model Specification (classification)
Computational engine: glm
# create cross-validation: to get better estimates of performance
set.seed(2021101002)
bird_folds <- vfold_cv(bird_train_data, v = 10, strata = bird_count) # since there is an imbalance
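vfold_cv() plus fit_resamples() amounts to: split the training data into 10 folds, fit the model on 9 of them, score it on the held-out fold, repeat for every fold, and summarise the scores (collect_metrics() reports their mean and standard error). A plain stats::glm() sketch of that loop on simulated data; all names and data here are hypothetical, and unlike the call above the folds are not stratified:

```r
set.seed(42)

# Hypothetical training data: one binary predictor, imbalanced binary outcome
n <- 500
toy <- data.frame(urban = rbinom(n, 1, 0.7))
toy$seen <- rbinom(n, 1, plogis(-3 + 0.8 * toy$urban))

v <- 10
fold_id <- sample(rep(seq_len(v), length.out = n))  # random fold assignment

fold_accuracy <- vapply(seq_len(v), function(k) {
  analysis   <- toy[fold_id != k, ]  # 9 folds to fit on
  assessment <- toy[fold_id == k, ]  # 1 held-out fold to score
  fit <- glm(seen ~ urban, family = binomial, data = analysis)
  prob <- predict(fit, newdata = assessment, type = "response")
  mean((prob > 0.5) == (assessment$seen == 1))  # accuracy at a 0.5 cut-off
}, numeric(1))

# Summaries in the style of collect_metrics(): mean and standard error
print(mean(fold_accuracy))
print(sd(fold_accuracy) / sqrt(v))
```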
# fit resamples to workflow:
doParallel::registerDoParallel()
control_pred <- tune::control_resamples(save_pred = TRUE)
lr_fit_rs <- fit_resamples(bird_workflow, bird_folds, control = control_pred)
# collect metrics
collect_metrics(lr_fit_rs) # acc = 0.966, auc = 0.858
# A tibble: 2 × 6
.metric .estimator mean n std_err .config
<chr> <chr> <dbl> <int> <dbl> <chr>
1 accuracy binary 0.966 10 0.000338 Preprocessor1_Model1
2 roc_auc binary 0.858 10 0.00307 Preprocessor1_Model1
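The accuracy of 0.966 needs a baseline. From the bird_count table earlier, 155,344 of the 161,057 rows have a count of 0, so a model that always predicts "not seen" would already be right about 96.5% of the time (these counts include the 169 rows later dropped for missing values, so the figure is approximate). ROC AUC is the more informative metric here. The arithmetic:

```r
# Counts taken from the bird_baths %>% count(bird_count) output above
n_not_seen <- 155344
n_total    <- 161057

# Accuracy of always predicting the majority class ("not seen")
baseline_accuracy <- n_not_seen / n_total
print(round(baseline_accuracy, 3))
```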
rs_roc_plot <- augment(lr_fit_rs) %>%
clean_names() %>% # so that pred_bird_seen can be used
roc_curve(bird_seen, pred_bird_seen) %>%
autoplot()
rs_roc_plot
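ROC AUC has a handy interpretation: it is the probability that a randomly chosen "bird seen" row gets a higher predicted probability than a randomly chosen "not seen" row (ties counting half). A base-R sketch using the rank (Mann-Whitney) form on hypothetical scores; `auc_rank()` is an illustrative helper, not a yardstick function:

```r
# Hypothetical predicted probabilities of "bird seen" and true classes
score <- c(0.9, 0.8, 0.7, 0.4, 0.3, 0.2)
truth <- c(TRUE, TRUE, FALSE, TRUE, FALSE, FALSE)  # TRUE = bird seen

auc_rank <- function(score, is_event) {
  n_pos <- sum(is_event)
  n_neg <- sum(!is_event)
  r <- rank(score)  # average ranks handle ties
  (sum(r[is_event]) - n_pos * (n_pos + 1) / 2) / (n_pos * n_neg)
}

print(auc_rank(score, truth))  # 8 of the 9 seen/not-seen pairs are ordered correctly
```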
bird_recipe_interact <-
bird_recipe %>%
step_interact(~starts_with("urban_rural"):starts_with("bird_type"))
bird_workflow_interact <-
workflow() %>%
add_model(lr_model) %>%
add_recipe(bird_recipe_interact)
lr_fit_rs_interact <- fit_resamples(bird_workflow_interact,
bird_folds,
control = control_pred)
collect_metrics(lr_fit_rs_interact) # acc = 0.966, auc = 0.869, slightly higher AUC
# A tibble: 2 × 6
.metric .estimator mean n std_err .config
<chr> <chr> <dbl> <int> <dbl> <chr>
1 accuracy binary 0.966 10 0.000338 Preprocessor1_Model1
2 roc_auc binary 0.869 10 0.00260 Preprocessor1_Model1
The interaction model has a slightly higher ROC AUC (0.869 vs 0.858), while accuracy is unchanged at 0.966. However, this model took about 4 hours to run on my laptop.
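Part of the extra cost is likely the size of the design matrix: after step_dummy(), step_interact() multiplies the single urban_rural dummy by each of the 168 bird_type dummies, adding 168 columns on top of the 169 main-effect columns. A base-R sketch of the column count with hypothetical levels:

```r
# Hypothetical data: 3 bird types x 2 area types
toy <- expand.grid(
  bird_type   = factor(c("Grey Fantail", "Noisy Miner", "Spotted Dove")),
  urban_rural = factor(c("Rural", "Urban"))
)

main_only   <- model.matrix(~ bird_type + urban_rural, data = toy)
interaction <- model.matrix(~ bird_type * urban_rural, data = toy)

# The interaction adds (3 - 1) * (2 - 1) = 2 product columns
print(ncol(interaction) - ncol(main_only))

# Same arithmetic for the real predictors: 169 bird types, 2 area types
main_cols     <- (169 - 1) + (2 - 1)
interact_cols <- (169 - 1) * (2 - 1)
print(c(main = main_cols, interaction = interact_cols, total = main_cols + interact_cols))
```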
lr_mod_df <- augment(lr_fit_rs) %>%
clean_names() %>%
roc_curve(bird_seen, pred_bird_seen) %>%
mutate(model = "log reg model")
lr_interact_mod_df <- augment(lr_fit_rs_interact) %>%
clean_names() %>%
roc_curve(bird_seen, pred_bird_seen) %>%
mutate(model = "log reg model with interaction")
bind_rows(lr_mod_df, lr_interact_mod_df) %>%
ggplot(aes(x = 1-specificity, y = sensitivity, col = model)) +
geom_path(lwd = 1.5, alpha = 0.5) +
geom_abline(lty = 3) +
coord_equal() +
scale_color_jco() +
theme_classic()
bird_fit <- parsnip::fit(bird_workflow_interact, bird_train_data) # fit on the training data; bird_test_data stays held out
predict(bird_fit, bird_test_data, type = "prob")
# A tibble: 40,222 × 2
`.pred_bird seen` `.pred_not seen`
<dbl> <dbl>
1 0.0133 0.987
2 0.00581 0.994
3 0.114 0.886
4 0.0629 0.937
5 0.00000000117 1.00
6 0.00000000117 1.00
7 0.195 0.805
8 0.0235 0.976
9 0.131 0.869
10 0.0850 0.915
# … with 40,212 more rows
# Creating new data frame
new_bird_data <-
tibble(bird_type = top_15_birds$bird_type) %>%
crossing(urban_rural = c("Urban", "Rural")) # all unique combinations
# Augment to new data
birds_predicted_new_df <- augment(bird_fit, new_bird_data) %>%
bind_cols(
predict(bird_fit, new_bird_data, type = "conf_int")
) %>%
clean_names()
birds_predicted_new_df
# A tibble: 30 × 9
bird_type urban_rural pred_class pred_bird_seen pred_not_seen
<fct> <chr> <fct> <dbl> <dbl>
1 Australian Magpie Rural not seen 0.205 0.795
2 Australian Magpie Urban not seen 0.254 0.746
3 Crested Pigeon Rural not seen 0.110 0.890
4 Crested Pigeon Urban not seen 0.168 0.832
5 Crimson Rosella Rural not seen 0.231 0.769
6 Crimson Rosella Urban not seen 0.131 0.869
7 Eastern Spinebill Rural not seen 0.321 0.679
8 Eastern Spinebill Urban not seen 0.114 0.886
9 Grey Fantail Rural not seen 0.301 0.699
10 Grey Fantail Urban not seen 0.0909 0.909
# … with 20 more rows, and 4 more variables:
# pred_lower_bird_seen <dbl>, pred_upper_bird_seen <dbl>,
# pred_lower_not_seen <dbl>, pred_upper_not_seen <dbl>
birds_predicted_new_df %>%
filter(bird_type == "Spotted Dove")
# A tibble: 2 × 9
bird_type urban_rural pred_class pred_bird_seen pred_not_seen
<fct> <chr> <fct> <dbl> <dbl>
1 Spotted Dove Rural not seen 0.00000000117 1.00
2 Spotted Dove Urban not seen 0.180 0.820
# … with 4 more variables: pred_lower_bird_seen <dbl>,
# pred_upper_bird_seen <dbl>, pred_lower_not_seen <dbl>,
# pred_upper_not_seen <dbl>
# visualizing
birds_predicted_new_df %>%
ggplot(aes(pred_bird_seen, bird_type, col = urban_rural)) +
geom_errorbar(aes(
xmin = pred_lower_bird_seen,
xmax = pred_upper_bird_seen
),
width = 0.2, size = 1.2, alpha = 0.5) +
geom_point(size = 3) +
scale_x_continuous(labels = scales::percent) +
labs(x = "Predicted probability of seeing bird",
y = "",
col = NULL,
caption = "Source: Cleary et al, 2016") +
scale_colour_jco() +
theme_classic()
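The spotted dove is very hard to spot in rural places! That said, a predicted probability of about 0.000000001 is suspiciously small. Estimates this extreme are the usual symptom of (quasi-)separation: if a bird was never recorded as seen in one setting in the data the model was fitted on, the glm coefficient drifts towards minus infinity instead of settling on a finite value. The averages computed earlier (top_15_urban_rural_wide) show spotted doves were recorded in rural areas (0.0374 on average), so the true rural probability is low, but probably not this low.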
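To see how a near-zero estimate like this arises, here is a base-R sketch with stats::glm() on made-up data in which a bird is never recorded as seen in rural rows. The maximum-likelihood estimate for that group does not exist, so glm() stops iterating at a very large negative value, giving a vanishingly small probability with a huge standard error. A penalized fit (for example parsnip's glmnet engine with a non-zero penalty) is one common way to keep such estimates finite.

```r
# Hypothetical data: 10 rural rows with no sightings, 10 urban rows with 5
toy <- data.frame(
  urban_rural = factor(rep(c("Rural", "Urban"), each = 10)),
  seen        = c(rep(0, 10), rep(c(0, 1), 5))
)

fit <- glm(seen ~ urban_rural, family = binomial, data = toy)

new_data <- data.frame(urban_rural = factor(c("Rural", "Urban"), levels = levels(toy$urban_rural)))
p_hat <- predict(fit, newdata = new_data, type = "response")

print(p_hat)                                       # Rural essentially 0, Urban 0.5
print(summary(fit)$coefficients[, "Std. Error"])   # very large standard errors
```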
For attribution, please cite this work as
lruolin (2021, Oct. 8). pRactice corner: Tidy Tuesday Bird Bath Data. Retrieved from https://lruolin.github.io/myBlog/posts/20211008 Tidy Tuesday Bird Bath Log Reg/
BibTeX citation
@misc{lruolin2021tidy, author = {lruolin, }, title = {pRactice corner: Tidy Tuesday Bird Bath Data}, url = {https://lruolin.github.io/myBlog/posts/20211008 Tidy Tuesday Bird Bath Log Reg/}, year = {2021} }