Overview
Logistic regression is a method for estimating the probability that an observation belongs to one of two classes given a vector of covariates. For example, given various demographic characteristics (age, sex, etc.), we can estimate the probability that a person owns a home. Important predictors would likely include age and income. The probabilities returned by the model can then be used for classification based on a cutoff (e.g., 0.5 is a common choice) or used for decision making based on more complex decision rules1.
1 This article goes into more detail on the difference between prediction of probabilities and classification.
In this post, we’ll explore how logistic regression works by implementing it by hand using a few different methods. The built-in functions in R can be a bit of a black box, so implementing algorithms by hand can aid understanding, even though it’s not practical for day-to-day data analysis projects.
Data
As an example dataset, we will use the Palmer Penguins data. The data include measurements on three penguin species from islands in the Palmer Archipelago.
We’ll load the data and save it as a data frame df.
Code
# Load and rename data
data(penguins)
df <- penguins
Then we will

- filter to two of the penguin species: Adelie and Gentoo
- create a binary variable, adelie, equal to 1 for Adelie and 0 for Gentoo
- select a subset of columns to keep
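These steps can be sketched with dplyr as follows; the exact column selection and the dropping of rows with missing measurements are assumptions based on the data summaries that follow.

Code
# Sketch of the data preparation described above
library(dplyr)
df <- df %>%
  # keep only Adelie and Gentoo penguins
  filter(species %in% c("Adelie", "Gentoo")) %>%
  # binary outcome: 1 = Adelie, 0 = Gentoo; drop the unused factor level
  mutate(adelie = ifelse(species == "Adelie", 1, 0),
         species = droplevels(species)) %>%
  # keep the columns used in the rest of the post
  select(species, adelie, bill_length_mm, bill_depth_mm,
         flipper_length_mm, body_mass_g) %>%
  # drop rows with missing measurements (assumption: the summaries below show no missing values)
  na.omit()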
Exploratory data analysis
You can explore the raw data below.
species | adelie | bill_length_mm | bill_depth_mm | flipper_length_mm | body_mass_g |
---|---|---|---|---|---|
Adelie | 1 | 39.1 | 18.7 | 181 | 3750 |
Adelie | 1 | 39.5 | 17.4 | 186 | 3800 |
Adelie | 1 | 40.3 | 18.0 | 195 | 3250 |
Adelie | 1 | 36.7 | 19.3 | 193 | 3450 |
Adelie | 1 | 39.3 | 20.6 | 190 | 3650 |
Adelie | 1 | 38.9 | 17.8 | 181 | 3625 |
Adelie | 1 | 39.2 | 19.6 | 195 | 4675 |
Adelie | 1 | 34.1 | 18.1 | 193 | 3475 |
Adelie | 1 | 42.0 | 20.2 | 190 | 4250 |
Adelie | 1 | 37.8 | 17.1 | 186 | 3300 |
Adelie | 1 | 37.8 | 17.3 | 180 | 3700 |
Adelie | 1 | 41.1 | 17.6 | 182 | 3200 |
Adelie | 1 | 38.6 | 21.2 | 191 | 3800 |
Adelie | 1 | 34.6 | 21.1 | 198 | 4400 |
Adelie | 1 | 36.6 | 17.8 | 185 | 3700 |
Adelie | 1 | 38.7 | 19.0 | 195 | 3450 |
Adelie | 1 | 42.5 | 20.7 | 197 | 4500 |
Adelie | 1 | 34.4 | 18.4 | 184 | 3325 |
Adelie | 1 | 46.0 | 21.5 | 194 | 4200 |
Adelie | 1 | 37.8 | 18.3 | 174 | 3400 |
Adelie | 1 | 37.7 | 18.7 | 180 | 3600 |
Adelie | 1 | 35.9 | 19.2 | 189 | 3800 |
Adelie | 1 | 38.2 | 18.1 | 185 | 3950 |
Adelie | 1 | 38.8 | 17.2 | 180 | 3800 |
Adelie | 1 | 35.3 | 18.9 | 187 | 3800 |
Adelie | 1 | 40.6 | 18.6 | 183 | 3550 |
Adelie | 1 | 40.5 | 17.9 | 187 | 3200 |
Adelie | 1 | 37.9 | 18.6 | 172 | 3150 |
Adelie | 1 | 40.5 | 18.9 | 180 | 3950 |
Adelie | 1 | 39.5 | 16.7 | 178 | 3250 |
Adelie | 1 | 37.2 | 18.1 | 178 | 3900 |
Adelie | 1 | 39.5 | 17.8 | 188 | 3300 |
Adelie | 1 | 40.9 | 18.9 | 184 | 3900 |
Adelie | 1 | 36.4 | 17.0 | 195 | 3325 |
Adelie | 1 | 39.2 | 21.1 | 196 | 4150 |
Adelie | 1 | 38.8 | 20.0 | 190 | 3950 |
Adelie | 1 | 42.2 | 18.5 | 180 | 3550 |
Adelie | 1 | 37.6 | 19.3 | 181 | 3300 |
Adelie | 1 | 39.8 | 19.1 | 184 | 4650 |
Adelie | 1 | 36.5 | 18.0 | 182 | 3150 |
Adelie | 1 | 40.8 | 18.4 | 195 | 3900 |
Adelie | 1 | 36.0 | 18.5 | 186 | 3100 |
Adelie | 1 | 44.1 | 19.7 | 196 | 4400 |
Adelie | 1 | 37.0 | 16.9 | 185 | 3000 |
Adelie | 1 | 39.6 | 18.8 | 190 | 4600 |
Adelie | 1 | 41.1 | 19.0 | 182 | 3425 |
Adelie | 1 | 37.5 | 18.9 | 179 | 2975 |
Adelie | 1 | 36.0 | 17.9 | 190 | 3450 |
Adelie | 1 | 42.3 | 21.2 | 191 | 4150 |
Adelie | 1 | 39.6 | 17.7 | 186 | 3500 |
Adelie | 1 | 40.1 | 18.9 | 188 | 4300 |
Adelie | 1 | 35.0 | 17.9 | 190 | 3450 |
Adelie | 1 | 42.0 | 19.5 | 200 | 4050 |
Adelie | 1 | 34.5 | 18.1 | 187 | 2900 |
Adelie | 1 | 41.4 | 18.6 | 191 | 3700 |
Adelie | 1 | 39.0 | 17.5 | 186 | 3550 |
Adelie | 1 | 40.6 | 18.8 | 193 | 3800 |
Adelie | 1 | 36.5 | 16.6 | 181 | 2850 |
Adelie | 1 | 37.6 | 19.1 | 194 | 3750 |
Adelie | 1 | 35.7 | 16.9 | 185 | 3150 |
Adelie | 1 | 41.3 | 21.1 | 195 | 4400 |
Adelie | 1 | 37.6 | 17.0 | 185 | 3600 |
Adelie | 1 | 41.1 | 18.2 | 192 | 4050 |
Adelie | 1 | 36.4 | 17.1 | 184 | 2850 |
Adelie | 1 | 41.6 | 18.0 | 192 | 3950 |
Adelie | 1 | 35.5 | 16.2 | 195 | 3350 |
Adelie | 1 | 41.1 | 19.1 | 188 | 4100 |
Adelie | 1 | 35.9 | 16.6 | 190 | 3050 |
Adelie | 1 | 41.8 | 19.4 | 198 | 4450 |
Adelie | 1 | 33.5 | 19.0 | 190 | 3600 |
Adelie | 1 | 39.7 | 18.4 | 190 | 3900 |
Adelie | 1 | 39.6 | 17.2 | 196 | 3550 |
Adelie | 1 | 45.8 | 18.9 | 197 | 4150 |
Adelie | 1 | 35.5 | 17.5 | 190 | 3700 |
Adelie | 1 | 42.8 | 18.5 | 195 | 4250 |
Adelie | 1 | 40.9 | 16.8 | 191 | 3700 |
Adelie | 1 | 37.2 | 19.4 | 184 | 3900 |
Adelie | 1 | 36.2 | 16.1 | 187 | 3550 |
Adelie | 1 | 42.1 | 19.1 | 195 | 4000 |
Adelie | 1 | 34.6 | 17.2 | 189 | 3200 |
Adelie | 1 | 42.9 | 17.6 | 196 | 4700 |
Adelie | 1 | 36.7 | 18.8 | 187 | 3800 |
Adelie | 1 | 35.1 | 19.4 | 193 | 4200 |
Adelie | 1 | 37.3 | 17.8 | 191 | 3350 |
Adelie | 1 | 41.3 | 20.3 | 194 | 3550 |
Adelie | 1 | 36.3 | 19.5 | 190 | 3800 |
Adelie | 1 | 36.9 | 18.6 | 189 | 3500 |
Adelie | 1 | 38.3 | 19.2 | 189 | 3950 |
Adelie | 1 | 38.9 | 18.8 | 190 | 3600 |
Adelie | 1 | 35.7 | 18.0 | 202 | 3550 |
Adelie | 1 | 41.1 | 18.1 | 205 | 4300 |
Adelie | 1 | 34.0 | 17.1 | 185 | 3400 |
Adelie | 1 | 39.6 | 18.1 | 186 | 4450 |
Adelie | 1 | 36.2 | 17.3 | 187 | 3300 |
Adelie | 1 | 40.8 | 18.9 | 208 | 4300 |
Adelie | 1 | 38.1 | 18.6 | 190 | 3700 |
Adelie | 1 | 40.3 | 18.5 | 196 | 4350 |
Adelie | 1 | 33.1 | 16.1 | 178 | 2900 |
Adelie | 1 | 43.2 | 18.5 | 192 | 4100 |
Adelie | 1 | 35.0 | 17.9 | 192 | 3725 |
Adelie | 1 | 41.0 | 20.0 | 203 | 4725 |
Adelie | 1 | 37.7 | 16.0 | 183 | 3075 |
Adelie | 1 | 37.8 | 20.0 | 190 | 4250 |
Adelie | 1 | 37.9 | 18.6 | 193 | 2925 |
Adelie | 1 | 39.7 | 18.9 | 184 | 3550 |
Adelie | 1 | 38.6 | 17.2 | 199 | 3750 |
Adelie | 1 | 38.2 | 20.0 | 190 | 3900 |
Adelie | 1 | 38.1 | 17.0 | 181 | 3175 |
Adelie | 1 | 43.2 | 19.0 | 197 | 4775 |
Adelie | 1 | 38.1 | 16.5 | 198 | 3825 |
Adelie | 1 | 45.6 | 20.3 | 191 | 4600 |
Adelie | 1 | 39.7 | 17.7 | 193 | 3200 |
Adelie | 1 | 42.2 | 19.5 | 197 | 4275 |
Adelie | 1 | 39.6 | 20.7 | 191 | 3900 |
Adelie | 1 | 42.7 | 18.3 | 196 | 4075 |
Adelie | 1 | 38.6 | 17.0 | 188 | 2900 |
Adelie | 1 | 37.3 | 20.5 | 199 | 3775 |
Adelie | 1 | 35.7 | 17.0 | 189 | 3350 |
Adelie | 1 | 41.1 | 18.6 | 189 | 3325 |
Adelie | 1 | 36.2 | 17.2 | 187 | 3150 |
Adelie | 1 | 37.7 | 19.8 | 198 | 3500 |
Adelie | 1 | 40.2 | 17.0 | 176 | 3450 |
Adelie | 1 | 41.4 | 18.5 | 202 | 3875 |
Adelie | 1 | 35.2 | 15.9 | 186 | 3050 |
Adelie | 1 | 40.6 | 19.0 | 199 | 4000 |
Adelie | 1 | 38.8 | 17.6 | 191 | 3275 |
Adelie | 1 | 41.5 | 18.3 | 195 | 4300 |
Adelie | 1 | 39.0 | 17.1 | 191 | 3050 |
Adelie | 1 | 44.1 | 18.0 | 210 | 4000 |
Adelie | 1 | 38.5 | 17.9 | 190 | 3325 |
Adelie | 1 | 43.1 | 19.2 | 197 | 3500 |
Adelie | 1 | 36.8 | 18.5 | 193 | 3500 |
Adelie | 1 | 37.5 | 18.5 | 199 | 4475 |
Adelie | 1 | 38.1 | 17.6 | 187 | 3425 |
Adelie | 1 | 41.1 | 17.5 | 190 | 3900 |
Adelie | 1 | 35.6 | 17.5 | 191 | 3175 |
Adelie | 1 | 40.2 | 20.1 | 200 | 3975 |
Adelie | 1 | 37.0 | 16.5 | 185 | 3400 |
Adelie | 1 | 39.7 | 17.9 | 193 | 4250 |
Adelie | 1 | 40.2 | 17.1 | 193 | 3400 |
Adelie | 1 | 40.6 | 17.2 | 187 | 3475 |
Adelie | 1 | 32.1 | 15.5 | 188 | 3050 |
Adelie | 1 | 40.7 | 17.0 | 190 | 3725 |
Adelie | 1 | 37.3 | 16.8 | 192 | 3000 |
Adelie | 1 | 39.0 | 18.7 | 185 | 3650 |
Adelie | 1 | 39.2 | 18.6 | 190 | 4250 |
Adelie | 1 | 36.6 | 18.4 | 184 | 3475 |
Adelie | 1 | 36.0 | 17.8 | 195 | 3450 |
Adelie | 1 | 37.8 | 18.1 | 193 | 3750 |
Adelie | 1 | 36.0 | 17.1 | 187 | 3700 |
Adelie | 1 | 41.5 | 18.5 | 201 | 4000 |
Gentoo | 0 | 46.1 | 13.2 | 211 | 4500 |
Gentoo | 0 | 50.0 | 16.3 | 230 | 5700 |
Gentoo | 0 | 48.7 | 14.1 | 210 | 4450 |
Gentoo | 0 | 50.0 | 15.2 | 218 | 5700 |
Gentoo | 0 | 47.6 | 14.5 | 215 | 5400 |
Gentoo | 0 | 46.5 | 13.5 | 210 | 4550 |
Gentoo | 0 | 45.4 | 14.6 | 211 | 4800 |
Gentoo | 0 | 46.7 | 15.3 | 219 | 5200 |
Gentoo | 0 | 43.3 | 13.4 | 209 | 4400 |
Gentoo | 0 | 46.8 | 15.4 | 215 | 5150 |
Gentoo | 0 | 40.9 | 13.7 | 214 | 4650 |
Gentoo | 0 | 49.0 | 16.1 | 216 | 5550 |
Gentoo | 0 | 45.5 | 13.7 | 214 | 4650 |
Gentoo | 0 | 48.4 | 14.6 | 213 | 5850 |
Gentoo | 0 | 45.8 | 14.6 | 210 | 4200 |
Gentoo | 0 | 49.3 | 15.7 | 217 | 5850 |
Gentoo | 0 | 42.0 | 13.5 | 210 | 4150 |
Gentoo | 0 | 49.2 | 15.2 | 221 | 6300 |
Gentoo | 0 | 46.2 | 14.5 | 209 | 4800 |
Gentoo | 0 | 48.7 | 15.1 | 222 | 5350 |
Gentoo | 0 | 50.2 | 14.3 | 218 | 5700 |
Gentoo | 0 | 45.1 | 14.5 | 215 | 5000 |
Gentoo | 0 | 46.5 | 14.5 | 213 | 4400 |
Gentoo | 0 | 46.3 | 15.8 | 215 | 5050 |
Gentoo | 0 | 42.9 | 13.1 | 215 | 5000 |
Gentoo | 0 | 46.1 | 15.1 | 215 | 5100 |
Gentoo | 0 | 44.5 | 14.3 | 216 | 4100 |
Gentoo | 0 | 47.8 | 15.0 | 215 | 5650 |
Gentoo | 0 | 48.2 | 14.3 | 210 | 4600 |
Gentoo | 0 | 50.0 | 15.3 | 220 | 5550 |
Gentoo | 0 | 47.3 | 15.3 | 222 | 5250 |
Gentoo | 0 | 42.8 | 14.2 | 209 | 4700 |
Gentoo | 0 | 45.1 | 14.5 | 207 | 5050 |
Gentoo | 0 | 59.6 | 17.0 | 230 | 6050 |
Gentoo | 0 | 49.1 | 14.8 | 220 | 5150 |
Gentoo | 0 | 48.4 | 16.3 | 220 | 5400 |
Gentoo | 0 | 42.6 | 13.7 | 213 | 4950 |
Gentoo | 0 | 44.4 | 17.3 | 219 | 5250 |
Gentoo | 0 | 44.0 | 13.6 | 208 | 4350 |
Gentoo | 0 | 48.7 | 15.7 | 208 | 5350 |
Gentoo | 0 | 42.7 | 13.7 | 208 | 3950 |
Gentoo | 0 | 49.6 | 16.0 | 225 | 5700 |
Gentoo | 0 | 45.3 | 13.7 | 210 | 4300 |
Gentoo | 0 | 49.6 | 15.0 | 216 | 4750 |
Gentoo | 0 | 50.5 | 15.9 | 222 | 5550 |
Gentoo | 0 | 43.6 | 13.9 | 217 | 4900 |
Gentoo | 0 | 45.5 | 13.9 | 210 | 4200 |
Gentoo | 0 | 50.5 | 15.9 | 225 | 5400 |
Gentoo | 0 | 44.9 | 13.3 | 213 | 5100 |
Gentoo | 0 | 45.2 | 15.8 | 215 | 5300 |
Gentoo | 0 | 46.6 | 14.2 | 210 | 4850 |
Gentoo | 0 | 48.5 | 14.1 | 220 | 5300 |
Gentoo | 0 | 45.1 | 14.4 | 210 | 4400 |
Gentoo | 0 | 50.1 | 15.0 | 225 | 5000 |
Gentoo | 0 | 46.5 | 14.4 | 217 | 4900 |
Gentoo | 0 | 45.0 | 15.4 | 220 | 5050 |
Gentoo | 0 | 43.8 | 13.9 | 208 | 4300 |
Gentoo | 0 | 45.5 | 15.0 | 220 | 5000 |
Gentoo | 0 | 43.2 | 14.5 | 208 | 4450 |
Gentoo | 0 | 50.4 | 15.3 | 224 | 5550 |
Gentoo | 0 | 45.3 | 13.8 | 208 | 4200 |
Gentoo | 0 | 46.2 | 14.9 | 221 | 5300 |
Gentoo | 0 | 45.7 | 13.9 | 214 | 4400 |
Gentoo | 0 | 54.3 | 15.7 | 231 | 5650 |
Gentoo | 0 | 45.8 | 14.2 | 219 | 4700 |
Gentoo | 0 | 49.8 | 16.8 | 230 | 5700 |
Gentoo | 0 | 46.2 | 14.4 | 214 | 4650 |
Gentoo | 0 | 49.5 | 16.2 | 229 | 5800 |
Gentoo | 0 | 43.5 | 14.2 | 220 | 4700 |
Gentoo | 0 | 50.7 | 15.0 | 223 | 5550 |
Gentoo | 0 | 47.7 | 15.0 | 216 | 4750 |
Gentoo | 0 | 46.4 | 15.6 | 221 | 5000 |
Gentoo | 0 | 48.2 | 15.6 | 221 | 5100 |
Gentoo | 0 | 46.5 | 14.8 | 217 | 5200 |
Gentoo | 0 | 46.4 | 15.0 | 216 | 4700 |
Gentoo | 0 | 48.6 | 16.0 | 230 | 5800 |
Gentoo | 0 | 47.5 | 14.2 | 209 | 4600 |
Gentoo | 0 | 51.1 | 16.3 | 220 | 6000 |
Gentoo | 0 | 45.2 | 13.8 | 215 | 4750 |
Gentoo | 0 | 45.2 | 16.4 | 223 | 5950 |
Gentoo | 0 | 49.1 | 14.5 | 212 | 4625 |
Gentoo | 0 | 52.5 | 15.6 | 221 | 5450 |
Gentoo | 0 | 47.4 | 14.6 | 212 | 4725 |
Gentoo | 0 | 50.0 | 15.9 | 224 | 5350 |
Gentoo | 0 | 44.9 | 13.8 | 212 | 4750 |
Gentoo | 0 | 50.8 | 17.3 | 228 | 5600 |
Gentoo | 0 | 43.4 | 14.4 | 218 | 4600 |
Gentoo | 0 | 51.3 | 14.2 | 218 | 5300 |
Gentoo | 0 | 47.5 | 14.0 | 212 | 4875 |
Gentoo | 0 | 52.1 | 17.0 | 230 | 5550 |
Gentoo | 0 | 47.5 | 15.0 | 218 | 4950 |
Gentoo | 0 | 52.2 | 17.1 | 228 | 5400 |
Gentoo | 0 | 45.5 | 14.5 | 212 | 4750 |
Gentoo | 0 | 49.5 | 16.1 | 224 | 5650 |
Gentoo | 0 | 44.5 | 14.7 | 214 | 4850 |
Gentoo | 0 | 50.8 | 15.7 | 226 | 5200 |
Gentoo | 0 | 49.4 | 15.8 | 216 | 4925 |
Gentoo | 0 | 46.9 | 14.6 | 222 | 4875 |
Gentoo | 0 | 48.4 | 14.4 | 203 | 4625 |
Gentoo | 0 | 51.1 | 16.5 | 225 | 5250 |
Gentoo | 0 | 48.5 | 15.0 | 219 | 4850 |
Gentoo | 0 | 55.9 | 17.0 | 228 | 5600 |
Gentoo | 0 | 47.2 | 15.5 | 215 | 4975 |
Gentoo | 0 | 49.1 | 15.0 | 228 | 5500 |
Gentoo | 0 | 47.3 | 13.8 | 216 | 4725 |
Gentoo | 0 | 46.8 | 16.1 | 215 | 5500 |
Gentoo | 0 | 41.7 | 14.7 | 210 | 4700 |
Gentoo | 0 | 53.4 | 15.8 | 219 | 5500 |
Gentoo | 0 | 43.3 | 14.0 | 208 | 4575 |
Gentoo | 0 | 48.1 | 15.1 | 209 | 5500 |
Gentoo | 0 | 50.5 | 15.2 | 216 | 5000 |
Gentoo | 0 | 49.8 | 15.9 | 229 | 5950 |
Gentoo | 0 | 43.5 | 15.2 | 213 | 4650 |
Gentoo | 0 | 51.5 | 16.3 | 230 | 5500 |
Gentoo | 0 | 46.2 | 14.1 | 217 | 4375 |
Gentoo | 0 | 55.1 | 16.0 | 230 | 5850 |
Gentoo | 0 | 44.5 | 15.7 | 217 | 4875 |
Gentoo | 0 | 48.8 | 16.2 | 222 | 6000 |
Gentoo | 0 | 47.2 | 13.7 | 214 | 4925 |
Gentoo | 0 | 46.8 | 14.3 | 215 | 4850 |
Gentoo | 0 | 50.4 | 15.7 | 222 | 5750 |
Gentoo | 0 | 45.2 | 14.8 | 212 | 5200 |
Gentoo | 0 | 49.9 | 16.1 | 213 | 5400 |
The Hmisc::describe() function can give us a quick summary of the data.
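The call itself is a one-liner (assuming the Hmisc package is installed):

Code
# Quick numeric and categorical summaries of every column
Hmisc::describe(df)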
6 Variables 274 Observations
species
n | missing | distinct |
274 | 0 | 2 |
Value | Adelie | Gentoo
---|---|---
Frequency | 151 | 123
Proportion | 0.551 | 0.449
adelie
n | missing | distinct | Info | Sum | Mean | Gmd |
274 | 0 | 2 | 0.742 | 151 | 0.5511 | 0.4966 |
bill_length_mm
n | missing | distinct | Info | Mean | Gmd | .05 | .10 | .25 | .50 | .75 | .90 | .95 |
274 | 0 | 146 | 1 | 42.7 | 5.944 | 35.43 | 36.20 | 38.35 | 42.00 | 46.68 | 49.80 | 50.73 |
bill_depth_mm
n | missing | distinct | Info | Mean | Gmd | .05 | .10 | .25 | .50 | .75 | .90 | .95 |
274 | 0 | 78 | 1 | 16.84 | 2.317 | 13.80 | 14.20 | 15.00 | 17.00 | 18.50 | 19.30 | 20.03 |
flipper_length_mm
n | missing | distinct | Info | Mean | Gmd | .05 | .10 | .25 | .50 | .75 | .90 | .95 |
274 | 0 | 54 | 0.999 | 202.2 | 17.23 | 181.0 | 184.0 | 190.0 | 198.0 | 215.0 | 222.0 | 226.7 |
body_mass_g
n | missing | distinct | Info | Mean | Gmd | .05 | .10 | .25 | .50 | .75 | .90 | .95 |
274 | 0 | 89 | 1 | 4318 | 962.3 | 3091 | 3282 | 3600 | 4262 | 4950 | 5535 | 5700 |
The plot below shows that Adelie and Gentoo penguins should be easy to distinguish based on the measured features, since there is little overlap between the two species. Because we want a bit of a challenge (and because the maximum likelihood estimates for logistic regression do not converge if the classes are perfectly separable), we will predict species based on bill length and body mass.
Code
df %>%
GGally::ggpairs(mapping = aes(color=species),
columns = c("bill_length_mm",
"bill_depth_mm",
"flipper_length_mm",
"body_mass_g"),
title = "Can these features distinguish Adelie and Gentoo penguins?") +
scale_color_brewer(palette="Dark2") +
scale_fill_brewer(palette="Dark2")
In order to help our algorithms converge, we will put our variables on a more common scale by converting bill length to cm and body mass to kg.
Code
df$bill_length_cm <- df$bill_length_mm / 10
df$body_mass_kg <- df$body_mass_g / 1000
Logistic regression overview
Logistic regression is a type of linear model2. In statistics, a linear model means linear in the parameters, so we are modeling the output as a linear function of the parameters.
2 Specifically, logistic regression is a type of generalized linear model (GLM) along with probit regression, Poisson regression, and other common models used in data analysis.
For example, if we were predicting bill length, we could create a linear model where bill length is normally distributed, with a mean determined as a linear function of body mass and species.
\[\begin{gather} \mu_i = \beta_0 + \beta_1 [\text{Body Mass}]_i + \beta_2 [\text{Species = Adelie}]_i \\ [\text{Bill Length}]_i \sim N(\mu_i, \sigma^2) \end{gather}\]
We have three parameters, \(\beta_0\), \(\beta_1\), and \(\beta_2\). We can determine the likelihood of the data given these parameters. Maximizing the likelihood is the most common way to estimate the parameters from data. The idea is that we tune the parameters until we find the set of parameters that made the observed data most likely.
However, this won’t quite work if we want to predict a binary outcome like species. We could form a model like this: \[\begin{gather} p_i = \beta_0 + \beta_1 [\text{Bill Length}]_i + \beta_2 [\text{Body Mass}]_i \\ [\text{Species}]_i \sim \operatorname{Bernoulli}(p_i) \end{gather}\] In words, each observation of a penguin is modeled as a Bernoulli random variable, where the probability of being Adelie is a linear function of bill length and body mass. The issue with this model is that as the parameters vary, the value of \(p_i\) can fall outside the range \([0,1]\), which doesn’t make sense if we are trying to model a probability.
The solution is to apply the expit function: \[ \operatorname{expit}(x) = \frac{e^{x}}{1+e^{x}} \]
This function takes in a real valued input and transforms it to lie within the range \([0,1]\). The expit function is also called the logistic function, hence the name “logistic regression”.
Let’s try it out in an example.
Code
# This is a naive implementation that can overflow for large x
# expit <- function(x) exp(x) / (1 + exp(x))
# Better to use the built-in version
expit <- plogis
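To see the characteristic S-shape, we can evaluate expit over a range of inputs and plot the curve; this is a minimal sketch using base graphics.

Code
x <- seq(-10, 10, by = 0.1)
plot(x, expit(x), type = "l",
     xlab = "x", ylab = "expit(x)",
     main = "The expit (logistic) function")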
We see that (approximately) anything below -5 gets squashed to zero and anything above 5 gets squashed to 1.
We then modify our model to be \[\begin{gather} p_i = \operatorname{expit} \left(\beta_0 + \beta_1 [\text{Bill Length}]_i + \beta_2 [\text{Body Mass}]_i \right) \\ [\text{Species}]_i \sim \operatorname{Bernoulli}(p_i) \end{gather}\] so now \(p_i\) is constrained to lie within \([0,1]\). This type of model, where we take a linear function of the parameters and then apply a non-linear function to it, is known as a generalized linear model (GLM).
The likelihood contribution of a single observation is \(p_i\) if it is Adelie and \(1-p_i\) if it is Gentoo, which we can write as
\[ \text{Likelihood}_i = p_i^{\text{Adelie}} (1-p_i)^{1 - \text{Adelie}} \]
This form of writing the Bernoulli PMF works because if Adelie = 1, then \(\text{Likelihood}_i = p_i^{1} (1-p_i)^{1 - 1} = p_i\) and if Adelie = 0, then \(\text{Likelihood}_i = p_i^{0} (1-p_i)^{1 - 0} = 1-p_i\).
Therefore, the log-likelihood contribution of a single observation is
\[ \text{Log-Likelihood}_i = [\text{[Adelie]}_i \times \log(p_i)] + [(1 - \text{[Adelie]}_i) \times \log(1-p_i)] \]
The log-likelihood of the entire dataset is just the sum of all the individual log-likelihoods, since we are assuming independent observations, so we have
\[ \text{Log-Likelihood} = \sum_{i=1}^{n} \left[ [\text{[Adelie]}_i \times \log(p_i)] + [(1 - \text{[Adelie]}_i) \times \log(1-p_i)] \right] \]
and now substituting in \(p_i\) in terms of the parameters, we have \[\begin{align} \text{Log-Likelihood} &= \sum_{i=1}^{n} [ \underbrace{[\text{[Adelie]}_i \times \log(\operatorname{expit} \left(\beta_0 + \beta_1 [\text{Bill Length}]_i + \beta_2 [\text{Body Mass}]_i \right))]}_{\text{Contribution from Adelie observations}} \\ &+ \underbrace{[(1 - \text{[Adelie]}_i) \times \log(1-\operatorname{expit} \left(\beta_0 + \beta_1 [\text{Bill Length}]_i + \beta_2 [\text{Body Mass}]_i \right))]}_{\text{Contribution from Gentoo observations}}] \end{align}\]
We can then pick \(\beta_0\), \(\beta_1\), and \(\beta_2\) to maximize this log-likelihood function, or as is often done in practice, minimize the negative log-likelihood function. We will need to do this with numerical methods, rather than obtaining an analytical solution with calculus, since no closed-form solution exists.
Logistic regression with glm()
Before we implement logistic regression by hand, we will use the glm() function in R as a baseline. Under the hood, R uses the Fisher scoring algorithm to obtain the maximum likelihood estimates.
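A minimal call consistent with the coefficients shown later looks like this; model_glm is the object name used in the code below, and family = binomial gives the logit link (i.e., logistic regression).

Code
# Fit the logistic regression of Adelie vs. Gentoo on bill length and body mass
model_glm <- glm(adelie ~ bill_length_cm + body_mass_kg,
                 family = binomial,
                 data = df)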
Now that we have fit the model, let’s look at the predictions of the model. We can make a grid of covariate values, and ask the model to give us the predicted probability of Species = Adelie for each one.
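The exact construction of the grid is not pinned down here, so this is one reasonable sketch: a regular grid spanning the observed ranges of the two covariates, with the number of grid points an arbitrary choice.

Code
# A regular grid over the observed ranges of bill length and body mass
grid <- expand.grid(
  bill_length_cm = seq(min(df$bill_length_cm), max(df$bill_length_cm), length.out = 100),
  body_mass_kg   = seq(min(df$body_mass_kg), max(df$body_mass_kg), length.out = 100)
)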
We use the predict() function to obtain the predicted probabilities. Setting type = "response" specifies that we want the predictions on the probability scale (i.e., after passing the linear predictor through the expit function).
Code
grid$predicted <- predict(model_glm, grid, type = "response")
bill_length_cm | body_mass_kg | predicted |
---|---|---|
5.287 | 6.105 | 0.0000001 |
5.755 | 2.840 | 0.0022277 |
5.657 | 5.825 | 0.0000000 |
4.462 | 4.905 | 0.0298950 |
4.519 | 6.065 | 0.0001169 |
6.130 | 3.090 | 0.0000257 |
6.092 | 4.370 | 0.0000001 |
5.616 | 6.050 | 0.0000000 |
5.533 | 4.675 | 0.0000055 |
3.442 | 5.815 | 0.8491884 |
4.557 | 3.765 | 0.6546862 |
4.358 | 3.360 | 0.9851913 |
3.477 | 6.020 | 0.6268662 |
5.550 | 3.805 | 0.0002095 |
4.230 | 3.785 | 0.9705454 |
5.127 | 6.315 | 0.0000002 |
5.961 | 3.870 | 0.0000039 |
5.010 | 4.880 | 0.0002480 |
5.167 | 4.430 | 0.0004301 |
4.263 | 4.605 | 0.4061394 |
4.620 | 2.835 | 0.9841835 |
5.678 | 4.025 | 0.0000254 |
5.131 | 5.045 | 0.0000406 |
4.408 | 3.405 | 0.9721137 |
5.389 | 3.790 | 0.0009517 |
5.665 | 3.765 | 0.0000886 |
5.761 | 5.505 | 0.0000000 |
3.427 | 5.460 | 0.9680849 |
5.227 | 4.880 | 0.0000352 |
4.676 | 4.290 | 0.0616915 |
3.896 | 5.735 | 0.1183568 |
5.180 | 4.980 | 0.0000347 |
5.415 | 5.100 | 0.0000025 |
6.023 | 3.875 | 0.0000022 |
5.420 | 6.360 | 0.0000000 |
4.552 | 4.990 | 0.0093732 |
6.091 | 4.235 | 0.0000002 |
3.870 | 4.635 | 0.9537189 |
3.853 | 6.050 | 0.0476244 |
4.815 | 6.290 | 0.0000031 |
3.967 | 5.675 | 0.0843093 |
3.115 | 5.355 | 0.9987433 |
3.299 | 4.345 | 0.9999197 |
4.809 | 3.710 | 0.1997304 |
4.837 | 4.660 | 0.0030633 |
5.581 | 5.730 | 0.0000000 |
3.797 | 4.505 | 0.9859334 |
4.694 | 2.875 | 0.9640952 |
3.320 | 4.840 | 0.9991593 |
4.583 | 2.790 | 0.9906229 |
4.185 | 5.960 | 0.0037194 |
3.292 | 5.390 | 0.9928432 |
3.282 | 3.575 | 0.9999976 |
4.709 | 3.975 | 0.1618652 |
4.154 | 5.960 | 0.0049102 |
4.552 | 2.830 | 0.9915448 |
4.398 | 5.740 | 0.0014320 |
3.052 | 6.130 | 0.9794326 |
6.015 | 3.260 | 0.0000344 |
4.159 | 5.855 | 0.0074037 |
4.354 | 3.515 | 0.9722746 |
5.859 | 4.545 | 0.0000005 |
4.807 | 3.205 | 0.6971030 |
5.838 | 5.110 | 0.0000001 |
4.537 | 3.715 | 0.7384344 |
4.324 | 3.945 | 0.8755539 |
3.760 | 4.185 | 0.9974751 |
4.820 | 3.090 | 0.7717730 |
4.623 | 5.070 | 0.0035106 |
4.036 | 6.190 | 0.0052033 |
6.145 | 4.945 | 0.0000000 |
4.555 | 4.835 | 0.0177903 |
6.054 | 3.085 | 0.0000520 |
4.706 | 3.920 | 0.2014218 |
5.669 | 2.995 | 0.0024554 |
4.568 | 4.730 | 0.0248441 |
5.785 | 3.630 | 0.0000543 |
3.620 | 3.245 | 0.9999881 |
5.188 | 4.825 | 0.0000635 |
5.144 | 6.030 | 0.0000005 |
4.999 | 3.615 | 0.0639715 |
6.022 | 6.100 | 0.0000000 |
4.351 | 3.780 | 0.9189363 |
6.035 | 4.170 | 0.0000005 |
4.533 | 4.725 | 0.0344495 |
5.689 | 4.160 | 0.0000127 |
5.407 | 5.570 | 0.0000003 |
3.911 | 4.500 | 0.9625201 |
3.749 | 4.335 | 0.9956077 |
5.552 | 5.745 | 0.0000000 |
5.122 | 6.180 | 0.0000003 |
3.918 | 4.625 | 0.9332259 |
5.408 | 2.965 | 0.0285411 |
4.192 | 4.515 | 0.6573858 |
4.271 | 4.180 | 0.8025823 |
5.068 | 5.340 | 0.0000198 |
5.500 | 3.905 | 0.0002124 |
4.897 | 5.295 | 0.0001121 |
4.881 | 5.770 | 0.0000163 |
3.734 | 6.205 | 0.0690677 |
3.611 | 3.320 | 0.9999848 |
4.886 | 6.345 | 0.0000013 |
4.080 | 3.250 | 0.9992383 |
4.767 | 4.825 | 0.0028003 |
4.130 | 2.910 | 0.9997289 |
5.421 | 4.875 | 0.0000063 |
3.676 | 3.510 | 0.9999375 |
5.557 | 3.725 | 0.0002789 |
5.320 | 4.835 | 0.0000185 |
6.129 | 5.795 | 0.0000000 |
3.630 | 3.565 | 0.9999475 |
4.129 | 3.040 | 0.9995264 |
4.154 | 5.745 | 0.0124515 |
4.196 | 3.350 | 0.9966620 |
5.988 | 2.820 | 0.0002992 |
5.690 | 5.445 | 0.0000000 |
5.984 | 6.300 | 0.0000000 |
3.060 | 3.565 | 0.9999997 |
4.456 | 6.390 | 0.0000499 |
5.438 | 6.260 | 0.0000000 |
4.113 | 4.120 | 0.9563162 |
5.717 | 5.045 | 0.0000002 |
3.565 | 5.730 | 0.7295421 |
4.147 | 5.665 | 0.0186824 |
5.093 | 4.645 | 0.0003276 |
4.597 | 4.685 | 0.0233260 |
3.761 | 4.955 | 0.9315152 |
5.658 | 6.305 | 0.0000000 |
4.344 | 4.125 | 0.7282171 |
5.474 | 3.295 | 0.0038291 |
4.507 | 2.870 | 0.9932731 |
5.661 | 5.665 | 0.0000000 |
5.026 | 3.465 | 0.0934974 |
3.544 | 3.975 | 0.9998551 |
3.012 | 4.840 | 0.9999474 |
4.975 | 6.395 | 0.0000005 |
4.090 | 5.605 | 0.0396736 |
5.575 | 3.145 | 0.0029717 |
5.280 | 2.835 | 0.1408342 |
3.980 | 5.415 | 0.2029946 |
4.258 | 5.770 | 0.0044150 |
3.699 | 4.095 | 0.9990138 |
4.102 | 5.925 | 0.0090953 |
3.902 | 3.265 | 0.9998360 |
5.447 | 6.125 | 0.0000000 |
3.768 | 4.290 | 0.9957173 |
5.557 | 3.495 | 0.0007604 |
3.714 | 5.900 | 0.2515647 |
3.159 | 5.185 | 0.9991104 |
3.248 | 5.330 | 0.9962798 |
3.642 | 5.820 | 0.4766904 |
6.032 | 5.700 | 0.0000000 |
5.133 | 4.245 | 0.0013080 |
4.436 | 3.130 | 0.9890056 |
3.985 | 4.550 | 0.9138578 |
3.517 | 2.930 | 0.9999988 |
3.224 | 5.535 | 0.9926940 |
4.313 | 6.095 | 0.0006542 |
3.398 | 3.685 | 0.9999890 |
5.902 | 4.725 | 0.0000002 |
3.988 | 5.290 | 0.2902346 |
4.044 | 2.950 | 0.9998511 |
4.945 | 4.070 | 0.0150290 |
5.294 | 5.270 | 0.0000035 |
5.089 | 3.955 | 0.0068499 |
5.898 | 3.965 | 0.0000046 |
5.198 | 5.745 | 0.0000010 |
4.600 | 3.155 | 0.9485557 |
4.088 | 4.580 | 0.7864920 |
6.091 | 5.140 | 0.0000000 |
4.000 | 2.960 | 0.9998953 |
5.637 | 4.175 | 0.0000191 |
4.899 | 4.180 | 0.0140821 |
4.033 | 4.820 | 0.6795311 |
4.439 | 4.805 | 0.0553886 |
5.062 | 6.345 | 0.0000003 |
4.060 | 5.130 | 0.3006920 |
3.703 | 3.485 | 0.9999285 |
4.389 | 5.585 | 0.0030489 |
4.853 | 3.350 | 0.4469271 |
5.445 | 2.915 | 0.0255251 |
3.892 | 3.755 | 0.9987300 |
4.750 | 4.125 | 0.0648959 |
5.812 | 3.925 | 0.0000117 |
5.153 | 5.090 | 0.0000274 |
5.506 | 6.315 | 0.0000000 |
5.427 | 3.025 | 0.0187022 |
4.570 | 4.230 | 0.1814927 |
4.618 | 3.130 | 0.9459126 |
3.347 | 5.420 | 0.9867009 |
3.997 | 5.325 | 0.2445391 |
5.281 | 3.210 | 0.0306600 |
4.710 | 2.955 | 0.9425279 |
3.492 | 5.755 | 0.8234821 |
5.227 | 6.175 | 0.0000001 |
4.670 | 4.955 | 0.0037975 |
5.405 | 3.075 | 0.0183351 |
3.841 | 4.875 | 0.9037304 |
5.406 | 4.850 | 0.0000080 |
4.031 | 5.600 | 0.0669916 |
Now that we have the predictions, let’s plot them and overlay the data with their true labels. The model looks to be performing pretty well!
Code
grid %>%
ggplot() +
aes(x=bill_length_cm,
y=body_mass_kg) +
geom_raster(aes(fill=predicted)) +
geom_point(data=df, mapping = aes(color=species)) +
geom_point(data=df, color="black", shape=21) +
scale_fill_viridis_c(breaks = seq(0, 1, 0.25),
limits=c(0,1)) +
scale_color_brewer(palette="Dark2") +
scale_x_continuous(breaks=seq(3, 6, 0.5)) +
scale_y_continuous(breaks=seq(3, 6, 0.5)) +
labs(fill = "Probability of Adelie\n",
color = "Species",
x = "Bill length (mm)",
y = "Body mass (g)",
title = "Visualizing the predictions of the logistic regression model") +
theme(legend.key.height = unit(1, "cm"))
Logistic regression with optim()
Now that we know what to expect after using glm(), let’s implement logistic regression by hand.
Recall that we would like to numerically determine the beta values that minimize the negative log-likelihood. The optim() function in R is a general-purpose function for minimizing functions3.
3 This short video is a good introduction to optim().
4 Sourced from here
optim() includes an algorithm called Nelder-Mead that searches the parameter space and converges on the minimum value. It is a direct search method that only requires the negative log-likelihood function as input (as opposed to gradient-based methods, which require the gradient of the negative log-likelihood to be supplied). This animation demonstrates the Nelder-Mead algorithm in action4.
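As a toy illustration of the interface (the function and starting values here are arbitrary and not part of the penguin analysis), optim() takes a starting parameter vector and a function to minimize:

Code
# Minimize f(x, y) = (x - 3)^2 + (y + 1)^2, whose minimum is at (3, -1)
toy_fn <- function(p) (p[1] - 3)^2 + (p[2] + 1)^2
optim(par = c(0, 0), fn = toy_fn, method = "Nelder-Mead")$par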
To use optim(), we create a function that takes as input the parameters and returns the negative log-likelihood. The below code is a translation of the mathematical notation from above.
Code
neg_loglikelihood_function <- function(parameters){
# optim() expects the parameters as a single vector, so we set the coefficients
# as the elements of a vector called `parameters`
b0 <- parameters[1]
b1 <- parameters[2]
b2 <- parameters[3]
linear_predictor <- (b0) + (b1*df$bill_length_cm) + (b2*df$body_mass_kg)
# Likelihood for each observation
# If the observation is Adelie, then the likelihood is the probability of Adelie
# If the observation is not Adelie (i.e., Gentoo), then the likelihood is the probability of not Adelie
# which is 1 - P(Adelie)
likelihood <- ifelse(df$adelie==1,
expit(linear_predictor),
1-expit(linear_predictor))
# Log-likelihood for each observation
log_likelihood <- log(likelihood)
# Joint log-likelihood for all the observations. Note the sum because
# multiplication is addition on the log-scale
total_log_likelihood <- sum(log_likelihood)
# optim() minimizes by default, so we return the negative log-likelihood;
# minimizing it is equivalent to maximizing the log-likelihood
return(-total_log_likelihood)
}
As an example, we can pass in \(\beta_0 = 1, \beta_1 = 2, \beta_2 = 3\) and see what the negative log-likelihood is.
Code
neg_loglikelihood_function(c(1,2,3))
[1] 3164.666
Because the negative log-likelihood is very high, we know that these are poor choices for the parameter values.
We can visualize the negative log-likelihood function for a variety of values. Since there are 3 parameters in our model, and we cannot visualize in 4D, we fix \(\beta_0 = 58.075\) (the optimized value found by glm()) and visualize how the negative log-likelihood varies with \(\beta_1\) and \(\beta_2\).
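The grid of \(\beta_1\) and \(\beta_2\) values underlying the heatmap can be built along these lines; the plotted ranges are assumptions chosen to bracket the glm() estimates.

Code
b0_fixed <- 58.075  # beta_0 held at the glm() estimate
grid <- expand.grid(b1 = seq(-12, -6, length.out = 100),
                    b2 = seq(-7, -2, length.out = 100))
# Evaluate the negative log-likelihood at each (b1, b2) pair
grid$neg_loglikelihood <- apply(grid, 1, function(row) {
  neg_loglikelihood_function(c(b0_fixed, row["b1"], row["b2"]))
})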
Code
# Show a heatmap of the negative log-likelihood with contour lines
grid %>%
ggplot() +
aes(x=b1,
y=b2) +
geom_raster(aes(fill=neg_loglikelihood)) +
geom_contour(aes(z=neg_loglikelihood), bins = 50, size=0.1, color="gray") +
scale_fill_viridis_c() +
scale_color_brewer(palette="Dark2") +
annotate(geom="point", x=-8.999, y=-4.363, color="red") +
labs(fill = "Negative log-likelihood\n",
x = "Beta 1",
y = "Beta 2",
title = "Visualizing the negative log-likelihood function") +
theme(legend.key.height = unit(1, "cm"))
To use optim(), we pass in the starting parameter values to par and the function to be minimized (the negative log-likelihood) to fn. Finally, we’ll specify method = "Nelder-Mead".
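A call along these lines does the job; the starting values are an assumption (c(0, 0, 0) is a simple default), and we raise maxit so Nelder-Mead has room to converge.

Code
optim_results <- optim(par = c(0, 0, 0),
                       fn = neg_loglikelihood_function,
                       method = "Nelder-Mead",
                       control = list(maxit = 5000))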
The maximum likelihood estimates are stored in the $par attribute of the optim object.
Code
optim_results$par
[1] 58.080999 -9.001399 -4.362164
which we can compare with the coefficients obtained from glm(), and we see that they match quite closely.
Code
coef(model_glm)
(Intercept) bill_length_cm body_mass_kg
58.074991 -8.998692 -4.363412
The below animation demonstrates the path taken by the Nelder-Mead algorithm5. As stated above, for the purpose of the animation we fix \(\beta_0\) at its optimized value of 58.075 and visualize how the negative log-likelihood is minimized with respect to \(\beta_1\) and \(\beta_2\).
5 The code for this animation is long, so it is not included here, but can be viewed in the source code of the Quarto document.
Logistic regression with gradient descent
Going one step further, instead of using a built-in optimization algorithm, let’s maximize the likelihood ourselves using gradient descent. If you need a refresher, I have written a blog post on gradient descent which you can find here.
We need the gradient of the negative log-likelihood function. The partial derivative with respect to the jth parameter is
\[\begin{align} \frac{\partial \, \text{NLL}}{\partial \beta_j} = \sum_{i=1}^{n} \left[ \operatorname{expit}(\boldsymbol{\beta} \cdot \mathbf{x}_i) - y_i \right] x_{ij} = \sum_{i=1}^{n} (\hat{y}_i - y_i)\, x_{ij} \end{align}\] so the full gradient can be written in matrix form as \[\begin{align} \mathbf{X}^T [\operatorname{expit}(\mathbf{X} \boldsymbol{\beta}) - \mathbf{y}] \end{align}\] or equivalently \[\begin{align} \mathbf{X}^T (\hat{\mathbf{y}} - \mathbf{y}) \end{align}\]
You can find a nice derivation of the derivative of the negative log-likelihood for logistic regression here.
Another approach is to use automatic differentiation. Automatic differentiation can be used to obtain gradients for arbitrary functions and is used heavily in deep learning. An example of doing this in R using the torch library is shown here.
We implement the above equations in the following function for the gradient.
Code
gradient <- function(parameters){
# Given a vector of parameters values, return the current gradient
b0 <- parameters[1]
b1 <- parameters[2]
b2 <- parameters[3]
# Define design matrix
X <- cbind(rep(1, nrow(df)),
df$bill_length_cm,
df$body_mass_kg)
beta <- matrix(parameters)
y_hat <- expit(X %*% beta)
gradient <- t(X) %*% (y_hat - df$adelie)
return(gradient)
}
Now we implement the gradient descent algorithm. We stop when the Euclidean distance between the new and old parameter vectors is less than \(10^{-6}\); we specify type = "2" in the norm() function to get the Euclidean length of the vector.
Code
set.seed(777)
step_size <- 0.001 # Learning rate
theta <- c(0,0,0) # Initial parameter value
iter <- 1
while (TRUE) {
iter <- iter + 1
current_gradient <- gradient(theta)
theta_new <- theta - (step_size * current_gradient)
if (norm(theta - theta_new, type="2") < 1e-6) {
break
} else{
theta <- theta_new
}
}
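The comparison below can be printed with something like the following; the exact printing code is an assumption.

Code
cat(paste("Final parameter values:", theta), sep = "\n")
cat(paste("`glm()` parameter values (for comparison):", coef(model_glm)), sep = "\n")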
Final parameter values: 57.9692967372787
Final parameter values: -8.98174736147028
Final parameter values: -4.35607256440227
`glm()` parameter values (for comparison): 58.0749906489601
`glm()` parameter values (for comparison): -8.99869242178258
`glm()` parameter values (for comparison): -4.36341215126915
Again, we see that the results are very close to the glm() results.
Conclusion
Hopefully this post was helpful for understanding the inner workings of logistic regression and how the principles can be extended to other types of models. For example, Poisson regression is another type of generalized linear model, just like logistic regression; in that case we use the exp function instead of the expit function, which constrains the predicted mean to lie in \((0, \infty)\).
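As a quick hypothetical sketch (simulated data, not the penguin dataset), fitting a Poisson GLM in R only requires changing the family:

Code
# Simulate counts whose log-mean is linear in x, then recover the coefficients
set.seed(42)
toy <- data.frame(x = runif(100, 0, 2))
toy$count <- rpois(100, lambda = exp(0.5 + 1.2 * toy$x))
model_pois <- glm(count ~ x, family = poisson, data = toy)
coef(model_pois)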