In the previous step we explored the data and created plots using the tidyverse and htmlwidgets. In this step we will create two statistical models that predict income over $50k. The first model is a classic logistic regression and the second uses the elastic net, which combines L1 and L2 regularization penalties. All predictors are categorical in this analysis.
library(tidyverse)
library(plotly)
library(pROC)
library(glmnet)
train <- read_csv("data/train.csv")
test <- read_csv("data/test.csv")
The logistic model uses main effects only against the training data. No regularization is applied. We assess the model fit with a holdout sample. We can build logistic models with the stats package.
Gender, education, and marital status are all highly significant. Marital status in particular is a strong predictor of earning more than $50k.
m1 <- glm(label ~ gender + native_country + education + occupation + workclass + marital_status +
  race + age_buckets, family = binomial, data = train)
summary(m1)
Call:
glm(formula = label ~ gender + native_country + education + occupation +
workclass + marital_status + race + age_buckets, family = binomial,
data = train)
Deviance Residuals:
Min 1Q Median 3Q Max
-2.6001 -0.5627 -0.2295 -0.0001 3.8200
Coefficients:
Estimate Std. Error z value Pr(>|z|)
(Intercept) -16.15783 81.90324 -0.197 0.84361
genderMale 0.30583 0.05027 6.083 1.18e-09 ***
native_countryCanada -0.67372 0.67788 -0.994 0.32029
native_countryChina -1.78499 0.69178 -2.580 0.00987 **
native_countryColumbia -3.14179 1.00637 -3.122 0.00180 **
native_countryCuba -0.78317 0.68906 -1.137 0.25572
native_countryDominican-Republic -2.13266 0.98598 -2.163 0.03054 *
native_countryEcuador -1.22454 0.93018 -1.316 0.18802
native_countryEl-Salvador -1.39951 0.76844 -1.821 0.06857 .
native_countryEngland -0.53931 0.69006 -0.782 0.43449
native_countryFrance -0.38991 0.81306 -0.480 0.63154
native_countryGermany -0.57296 0.66658 -0.860 0.39003
native_countryGreece -1.70925 0.81647 -2.093 0.03631 *
native_countryGuatemala -1.13839 0.92321 -1.233 0.21755
native_countryHaiti -1.31481 0.88943 -1.478 0.13934
native_countryHoland-Netherlands -13.77494 2399.54480 -0.006 0.99542
native_countryHonduras -1.77411 1.75548 -1.011 0.31220
native_countryHong -1.21231 0.87454 -1.386 0.16568
native_countryHungary -0.88362 0.97130 -0.910 0.36297
native_countryIndia -1.53554 0.66047 -2.325 0.02008 *
native_countryIran -0.93288 0.73550 -1.268 0.20467
native_countryIreland -0.39039 0.87096 -0.448 0.65399
native_countryItaly -0.28367 0.69837 -0.406 0.68460
native_countryJamaica -1.14196 0.74807 -1.527 0.12688
native_countryJapan -0.84284 0.71013 -1.187 0.23527
native_countryLaos -1.67837 1.08622 -1.545 0.12231
native_countryMexico -1.49385 0.65548 -2.279 0.02267 *
native_countryNicaragua -1.64888 1.01535 -1.624 0.10439
native_countryOutlying-US(Guam-USVI-etc) -15.24852 579.46692 -0.026 0.97901
native_countryPeru -2.00894 1.01398 -1.981 0.04756 *
native_countryPhilippines -0.84343 0.63640 -1.325 0.18506
native_countryPoland -1.12283 0.73874 -1.520 0.12853
native_countryPortugal -1.08854 0.89048 -1.222 0.22155
native_countryPuerto-Rico -1.36009 0.72420 -1.878 0.06037 .
native_countryScotland -1.52519 1.11662 -1.366 0.17197
native_countrySouth -1.97837 0.70600 -2.802 0.00508 **
native_countryTaiwan -1.43724 0.73488 -1.956 0.05049 .
native_countryThailand -1.47170 0.99100 -1.485 0.13753
native_countryTrinadad&Tobago -1.47980 1.01391 -1.460 0.14443
native_countryUnited-States -0.80574 0.62400 -1.291 0.19661
native_countryVietnam -2.01444 0.82255 -2.449 0.01432 *
native_countryYugoslavia -0.27549 0.89965 -0.306 0.75944
education11th 0.10758 0.20589 0.523 0.60131
education12th 0.51047 0.26358 1.937 0.05278 .
education1st-4th -0.57722 0.47004 -1.228 0.21943
education5th-6th -0.41266 0.34977 -1.180 0.23809
education7th-8th -0.45579 0.23369 -1.950 0.05113 .
education9th -0.31743 0.26050 -1.219 0.22302
educationAssoc-acdm 1.20443 0.17195 7.004 2.48e-12 ***
educationAssoc-voc 1.20440 0.16503 7.298 2.92e-13 ***
educationBachelors 1.90172 0.15370 12.373 < 2e-16 ***
educationDoctorate 3.10123 0.20970 14.789 < 2e-16 ***
educationHS-grad 0.73165 0.14945 4.896 9.80e-07 ***
educationMasters 2.28057 0.16371 13.930 < 2e-16 ***
educationPreschool -13.08227 301.13781 -0.043 0.96535
educationProf-school 3.05504 0.19480 15.683 < 2e-16 ***
educationSome-college 1.06187 0.15168 7.001 2.55e-12 ***
occupationArmed-Forces -0.87568 1.41321 -0.620 0.53550
occupationCraft-repair -0.03640 0.07594 -0.479 0.63172
occupationExec-managerial 0.86970 0.07205 12.072 < 2e-16 ***
occupationFarming-fishing -0.74541 0.12892 -5.782 7.38e-09 ***
occupationHandlers-cleaners -0.76078 0.13892 -5.477 4.34e-08 ***
occupationMachine-op-inspct -0.38296 0.09814 -3.902 9.53e-05 ***
occupationOther-service -0.91502 0.11293 -8.102 5.39e-16 ***
occupationPriv-house-serv -2.16427 1.02326 -2.115 0.03442 *
occupationProf-specialty 0.53786 0.07667 7.015 2.30e-12 ***
occupationProtective-serv 0.65652 0.12144 5.406 6.43e-08 ***
occupationSales 0.36815 0.07736 4.759 1.95e-06 ***
occupationTech-support 0.55767 0.10697 5.213 1.86e-07 ***
occupationTransport-moving -0.08087 0.09426 -0.858 0.39092
workclassLocal-gov -0.64511 0.10729 -6.013 1.82e-09 ***
workclassPrivate -0.35831 0.08925 -4.015 5.95e-05 ***
workclassSelf-emp-inc 0.04159 0.11751 0.354 0.72340
workclassSelf-emp-not-inc -0.74070 0.10435 -7.098 1.26e-12 ***
workclassState-gov -0.84546 0.11962 -7.068 1.57e-12 ***
workclassWithout-pay -14.96782 542.24463 -0.028 0.97798
marital_statusMarried-AF-spouse 3.24802 0.48912 6.640 3.13e-11 ***
marital_statusMarried-civ-spouse 2.10784 0.06209 33.949 < 2e-16 ***
marital_statusMarried-spouse-absent 0.09028 0.21530 0.419 0.67500
marital_statusNever-married -0.18746 0.07735 -2.424 0.01537 *
marital_statusSeparated -0.10303 0.14811 -0.696 0.48666
marital_statusWidowed 0.33062 0.13960 2.368 0.01787 *
raceAsian-Pac-Islander 0.68870 0.26582 2.591 0.00957 **
raceBlack 0.41753 0.22289 1.873 0.06103 .
raceOther 0.02531 0.35842 0.071 0.94369
raceWhite 0.56547 0.21275 2.658 0.00786 **
age_buckets(18,25] 11.13552 81.90046 0.136 0.89185
age_buckets(25,30] 12.21290 81.90042 0.149 0.88146
age_buckets(30,35] 12.66809 81.90042 0.155 0.87708
age_buckets(35,40] 13.10329 81.90042 0.160 0.87289
age_buckets(40,45] 13.16779 81.90042 0.161 0.87227
age_buckets(45,50] 13.33564 81.90042 0.163 0.87065
age_buckets(50,55] 13.36368 81.90043 0.163 0.87038
age_buckets(55,60] 13.19114 81.90043 0.161 0.87204
age_buckets(60,65] 12.76302 81.90046 0.156 0.87616
age_buckets(65,90] 12.39737 81.90047 0.151 0.87968
---
Signif. codes: 0 ‘***’ 0.001 ‘**’ 0.01 ‘*’ 0.05 ‘.’ 0.1 ‘ ’ 1
(Dispersion parameter for binomial family taken to be 1)
Null deviance: 33851 on 30161 degrees of freedom
Residual deviance: 21510 on 30066 degrees of freedom
AIC: 21702
Number of Fisher Scoring iterations: 15
anova(m1) # takes a while to run
Analysis of Deviance Table
Model: binomial, link: logit
Response: label
Terms added sequentially (first to last)
Df Deviance Resid. Df Resid. Dev
NULL 30161 33851
gender 1 1564.1 30160 32287
native_country 40 403.3 30120 31883
education 15 3674.7 30105 28209
occupation 13 1302.8 30092 26906
workclass 6 190.1 30086 26716
marital_status 6 4334.1 30080 22382
race 4 16.6 30076 22365
age_buckets 10 855.1 30066 21510
# plot(m1) # base R diagnostic plots are not very useful here
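To connect the summary output back to the claim about marital status, we can exponentiate selected coefficients to express them as odds ratios (a minimal sketch; for example, the Married-civ-spouse estimate of about 2.1 corresponds to an odds ratio of roughly 8).
# Express selected coefficients as odds ratios for easier interpretation
exp(coef(m1)[c("marital_statusMarried-civ-spouse", "educationBachelors", "genderMale")])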
The high area under the curve (AUC) of 0.883 indicates that this model might be overfitting. The lift chart shows that 80% of those in the top decile earn more than $50k, compared to a tiny fraction in the bottom decile.
# Predict
pred <- bind_rows("train" = train, "test" = test, .id = "data") %>%
mutate(pred = predict(m1, ., type = "response")) %>%
mutate(decile = ntile(desc(pred), 10)) %>%
select(data, label, pred, decile)
# ROC plot
pred %>%
filter(data == "test") %>%
roc(label ~ pred, .) %>%
plot.roc(., print.auc = TRUE)
# Lift plot
p <- pred %>%
group_by(data, decile) %>%
summarize(percent = 100 * mean(label)) %>%
ggplot(aes(decile, percent, fill = data)) + geom_bar(stat = "identity", position = "dodge") +
ggtitle("Lift chart for logistic regression model")
ggplotly(p)
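One quick check on the overfitting concern is to compare the AUC on the training and test sets; similar values suggest the model generalizes reasonably well. A minimal sketch using the pred data frame built above:
# Compare train and test AUC as an overfitting check
pred %>%
  group_by(data) %>%
  summarize(auc = as.numeric(auc(label, pred)))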
The elastic net is a regularized regression method that combines the L1 (lasso) and L2 (ridge) penalties. We can build elastic net models with the glmnet package.
Whereas the logistic model used a formula interface, the elastic net requires us to construct a model matrix from the categorical predictors. We then choose a value of the penalty parameter lambda that balances model fit against the size of the penalty. We can examine the predictor sets selected at different values of lambda. Optionally, we can use cross validation to programmatically determine the best choice of lambda.
# Convert to factors and build the model matrix
alldata <- bind_rows("train" = train, "test" = test, .id = "data") %>%
  select(-education_num) %>%
  mutate_all(factor) %>%
  model.matrix( ~ ., .)
# Create training prediction matrix
train.factors <- list(x = alldata[alldata[,'datatrain'] == 1, -(1:3)],
y = alldata[alldata[,'datatrain'] == 1, 3])
# Create test prediction matrix
test.factors <- list(x = alldata[alldata[,'datatrain'] == 0, -(1:3)],
y = alldata[alldata[,'datatrain'] == 0, 3])
# Fit a regularized model
fit1 <- glmnet(train.factors$x, train.factors$y, family = "binomial")
plot(fit1)
print(fit1)
Call: glmnet(x = train.factors$x, y = train.factors$y, family = "binomial")
Df %Dev Lambda
[1,] 0 -2.170e-13 0.1926000
[2,] 1 2.996e-02 0.1755000
[3,] 1 5.485e-02 0.1599000
[4,] 1 7.566e-02 0.1457000
[5,] 1 9.312e-02 0.1327000
[6,] 1 1.078e-01 0.1210000
[7,] 1 1.202e-01 0.1102000
[8,] 1 1.308e-01 0.1004000
[9,] 1 1.396e-01 0.0915000
[10,] 1 1.472e-01 0.0833700
[11,] 2 1.564e-01 0.0759600
[12,] 3 1.710e-01 0.0692100
[13,] 5 1.852e-01 0.0630700
[14,] 5 1.989e-01 0.0574600
[15,] 7 2.121e-01 0.0523600
[16,] 7 2.246e-01 0.0477100
[17,] 8 2.356e-01 0.0434700
[18,] 8 2.461e-01 0.0396100
[19,] 8 2.551e-01 0.0360900
[20,] 10 2.629e-01 0.0328800
[21,] 11 2.713e-01 0.0299600
[22,] 13 2.788e-01 0.0273000
[23,] 14 2.855e-01 0.0248700
[24,] 14 2.913e-01 0.0226600
[25,] 17 2.970e-01 0.0206500
[26,] 17 3.023e-01 0.0188200
[27,] 23 3.074e-01 0.0171400
[28,] 26 3.127e-01 0.0156200
[29,] 27 3.177e-01 0.0142300
[30,] 28 3.221e-01 0.0129700
[31,] 30 3.261e-01 0.0118200
[32,] 34 3.300e-01 0.0107700
[33,] 35 3.337e-01 0.0098110
[34,] 35 3.369e-01 0.0089390
[35,] 37 3.396e-01 0.0081450
[36,] 38 3.421e-01 0.0074220
[37,] 39 3.443e-01 0.0067620
[38,] 40 3.463e-01 0.0061620
[39,] 40 3.479e-01 0.0056140
[40,] 42 3.496e-01 0.0051150
[41,] 43 3.512e-01 0.0046610
[42,] 44 3.526e-01 0.0042470
[43,] 47 3.538e-01 0.0038700
[44,] 50 3.548e-01 0.0035260
[45,] 51 3.557e-01 0.0032130
[46,] 54 3.565e-01 0.0029270
[47,] 57 3.572e-01 0.0026670
[48,] 60 3.578e-01 0.0024300
[49,] 63 3.583e-01 0.0022140
[50,] 65 3.588e-01 0.0020180
[51,] 66 3.592e-01 0.0018380
[52,] 67 3.597e-01 0.0016750
[53,] 71 3.601e-01 0.0015260
[54,] 73 3.604e-01 0.0013910
[55,] 74 3.607e-01 0.0012670
[56,] 74 3.610e-01 0.0011550
[57,] 77 3.612e-01 0.0010520
[58,] 78 3.615e-01 0.0009585
[59,] 80 3.618e-01 0.0008734
[60,] 81 3.621e-01 0.0007958
[61,] 81 3.623e-01 0.0007251
[62,] 81 3.624e-01 0.0006607
[63,] 82 3.625e-01 0.0006020
[64,] 82 3.627e-01 0.0005485
[65,] 84 3.628e-01 0.0004998
[66,] 88 3.628e-01 0.0004554
[67,] 90 3.629e-01 0.0004149
[68,] 91 3.631e-01 0.0003781
[69,] 91 3.633e-01 0.0003445
[70,] 92 3.634e-01 0.0003139
[71,] 91 3.634e-01 0.0002860
[72,] 91 3.635e-01 0.0002606
[73,] 92 3.635e-01 0.0002374
[74,] 92 3.636e-01 0.0002163
[75,] 92 3.636e-01 0.0001971
[76,] 90 3.638e-01 0.0001796
[77,] 90 3.638e-01 0.0001637
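Note that glmnet's alpha argument controls the mix of L1 and L2 penalties: the default alpha = 1 is the pure lasso, while alpha = 0 is ridge. A minimal sketch of fitting with an explicit mix of both penalties, if that is the behavior you want:
# Fit with an explicit L1/L2 mix (alpha = 0.5) instead of the default lasso penalty
fit.mix <- glmnet(train.factors$x, train.factors$y, family = "binomial", alpha = 0.5)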
(m2 <- coef(fit1, s = 0.02)) # extract coefficients at a single value of lambda
96 x 1 sparse Matrix of class "dgCMatrix"
1
(Intercept) -2.48987926
genderMale .
native_countryCanada .
native_countryChina .
native_countryColumbia .
native_countryCuba .
native_countryDominican-Republic .
native_countryEcuador .
native_countryEl-Salvador .
native_countryEngland .
native_countryFrance .
native_countryGermany .
native_countryGreece .
native_countryGuatemala .
native_countryHaiti .
native_countryHoland-Netherlands .
native_countryHonduras .
native_countryHong .
native_countryHungary .
native_countryIndia .
native_countryIran .
native_countryIreland .
native_countryItaly .
native_countryJamaica .
native_countryJapan .
native_countryLaos .
native_countryMexico .
native_countryNicaragua .
native_countryOutlying-US(Guam-USVI-etc) .
native_countryPeru .
native_countryPhilippines .
native_countryPoland .
native_countryPortugal .
native_countryPuerto-Rico .
native_countryScotland .
native_countrySouth .
native_countryTaiwan .
native_countryThailand .
native_countryTrinadad&Tobago .
native_countryUnited-States .
native_countryVietnam .
native_countryYugoslavia .
education11th .
education12th .
education1st-4th .
education5th-6th .
education7th-8th -0.10172508
education9th .
educationAssoc-acdm .
educationAssoc-voc .
educationBachelors 0.69417178
educationDoctorate 1.04226523
educationHS-grad -0.03860506
educationMasters 0.89614956
educationPreschool .
educationProf-school 1.19493744
educationSome-college .
occupationArmed-Forces .
occupationCraft-repair .
occupationExec-managerial 0.69611385
occupationFarming-fishing -0.06129799
occupationHandlers-cleaners .
occupationMachine-op-inspct .
occupationOther-service -0.28702345
occupationPriv-house-serv .
occupationProf-specialty 0.43502928
occupationProtective-serv .
occupationSales .
occupationTech-support .
occupationTransport-moving .
workclassLocal-gov .
workclassPrivate .
workclassSelf-emp-inc 0.20665622
workclassSelf-emp-not-inc .
workclassState-gov .
workclassWithout-pay .
marital_statusMarried-AF-spouse .
marital_statusMarried-civ-spouse 1.85107323
marital_statusMarried-spouse-absent .
marital_statusNever-married -0.07063281
marital_statusSeparated .
marital_statusWidowed .
raceAsian-Pac-Islander .
raceBlack .
raceOther .
raceWhite .
age_buckets(18,25] -0.72088908
age_buckets(25,30] -0.24146074
age_buckets(30,35] .
age_buckets(35,40] .
age_buckets(40,45] .
age_buckets(45,50] 0.08753949
age_buckets(50,55] 0.01396170
age_buckets(55,60] .
age_buckets(60,65] .
age_buckets(65,90] .
# Cross validation (long running for full dataset)
cvfit <- cv.glmnet(train.factors$x, train.factors$y, family = "binomial", type.measure = "class")
plot(cvfit)
cvfit$lambda.min # 0.0001971255
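Alternatively, predictions at the cross-validated lambda can be drawn directly from the cv.glmnet object (a minimal sketch):
# Score the test set at the cross-validated lambda
cv.pred <- predict(cvfit, test.factors$x, s = "lambda.min", type = "response")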
Once you have chosen a value for lambda, you can score the test set and examine the ROC and lift charts. This model has a slightly lower AUC and lift than the logistic regression, but the overall results look very similar.
# Predict and plot the AUC
test.factors$pred <- predict(fit1, test.factors$x, s=0.02, type = "response") # make predictions
data.frame(resp = test.factors$y, pred = c(test.factors$pred)) %>%
roc(resp ~ pred, .) %>%
plot.roc(., print.auc = TRUE)
# Lift chart
p <- data.frame(data = ifelse(alldata[, 'datatrain'], "train", "test"),
label = alldata[,'label1'],
pred = c(predict.glmnet(fit1, alldata[, -(1:3)], s=0.02))) %>%
mutate(decile = ntile(desc(pred), 10)) %>%
group_by(data, decile) %>%
summarize(percent = 100 * mean(label)) %>%
ggplot(aes(decile, percent, fill = data)) + geom_bar(stat = "identity", position = "dodge") +
ggtitle("Lift chart for elastic net model")
ggplotly(p)
Finally, save the predicted output and the models for building apps.
# Score predictions
pred.out <- test %>%
  mutate(pred.glm = pred$pred[pred$data == "test"]) %>%
  mutate(pred.net = c(test.factors$pred)) %>%
  mutate(income_bracket = ifelse(label, ">50K", "<=50K"))
# Output predictions to file
write_csv(pred.out, "data/pred.csv")
saveRDS(m1, file = "data/logisticModel.rds")
saveRDS(m2, file = "data/elasticnetModel.rds")
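When building an app, the saved objects can be read back in and used to score new records. A sketch for the logistic model, using test as a stand-in for whatever new data the app receives:
# Reload the saved logistic model and score new data
m1.saved <- readRDS("data/logisticModel.rds")
head(predict(m1.saved, newdata = test, type = "response"))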