A while ago, during my time at Centura Health, I began researching the different ways we could analyze the data at hand, and at some point I became very interested in HR analytics. It is a fascinating way to explore an organization through data. Of course, while digesting the org structure through network graphs is interesting, employee data also offers a more immediately clear benefit: predicting turnover.
I came across this article on using machine learning, specifically H2O and LIME, for turnover prediction and really liked the simplicity of the features included, as well as the author’s explanation. After setting the example aside for some time, I finally found the time to work through it and, peeking under the H2O hood, attempt to replicate the results using GLMNET.
Unfortunately, the link to the data in the original article is broken. After a little Googling I was able to find some conscientious GitHub repos with a copy. I did some checks, and the data seemed to match the data in the article. Since I no longer remember where I found the data, it’s hosted in my own repo now.
Here we’ll look at the employee data and convert characters to factors.
employee_attrition_raw %>%
head(10) %>%
knitr::kable(format = 'html') %>%
kableExtra::kable_styling() %>%
kableExtra::scroll_box(width = "100%", height = "750px")
Age | Attrition | BusinessTravel | DailyRate | Department | DistanceFromHome | Education | EducationField | EmployeeCount | EmployeeNumber | EnvironmentSatisfaction | Gender | HourlyRate | JobInvolvement | JobLevel | JobRole | JobSatisfaction | MaritalStatus | MonthlyIncome | MonthlyRate | NumCompaniesWorked | Over18 | OverTime | PercentSalaryHike | PerformanceRating | RelationshipSatisfaction | StandardHours | StockOptionLevel | TotalWorkingYears | TrainingTimesLastYear | WorkLifeBalance | YearsAtCompany | YearsInCurrentRole | YearsSinceLastPromotion | YearsWithCurrManager |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
41 | Yes | Travel_Rarely | 1102 | Sales | 1 | 2 | Life Sciences | 1 | 1 | 2 | Female | 94 | 3 | 2 | Sales Executive | 4 | Single | 5993 | 19479 | 8 | Y | Yes | 11 | 3 | 1 | 80 | 0 | 8 | 0 | 1 | 6 | 4 | 0 | 5 |
49 | No | Travel_Frequently | 279 | Research & Development | 8 | 1 | Life Sciences | 1 | 2 | 3 | Male | 61 | 2 | 2 | Research Scientist | 2 | Married | 5130 | 24907 | 1 | Y | No | 23 | 4 | 4 | 80 | 1 | 10 | 3 | 3 | 10 | 7 | 1 | 7 |
37 | Yes | Travel_Rarely | 1373 | Research & Development | 2 | 2 | Other | 1 | 4 | 4 | Male | 92 | 2 | 1 | Laboratory Technician | 3 | Single | 2090 | 2396 | 6 | Y | Yes | 15 | 3 | 2 | 80 | 0 | 7 | 3 | 3 | 0 | 0 | 0 | 0 |
33 | No | Travel_Frequently | 1392 | Research & Development | 3 | 4 | Life Sciences | 1 | 5 | 4 | Female | 56 | 3 | 1 | Research Scientist | 3 | Married | 2909 | 23159 | 1 | Y | Yes | 11 | 3 | 3 | 80 | 0 | 8 | 3 | 3 | 8 | 7 | 3 | 0 |
27 | No | Travel_Rarely | 591 | Research & Development | 2 | 1 | Medical | 1 | 7 | 1 | Male | 40 | 3 | 1 | Laboratory Technician | 2 | Married | 3468 | 16632 | 9 | Y | No | 12 | 3 | 4 | 80 | 1 | 6 | 3 | 3 | 2 | 2 | 2 | 2 |
32 | No | Travel_Frequently | 1005 | Research & Development | 2 | 2 | Life Sciences | 1 | 8 | 4 | Male | 79 | 3 | 1 | Laboratory Technician | 4 | Single | 3068 | 11864 | 0 | Y | No | 13 | 3 | 3 | 80 | 0 | 8 | 2 | 2 | 7 | 7 | 3 | 6 |
59 | No | Travel_Rarely | 1324 | Research & Development | 3 | 3 | Medical | 1 | 10 | 3 | Female | 81 | 4 | 1 | Laboratory Technician | 1 | Married | 2670 | 9964 | 4 | Y | Yes | 20 | 4 | 1 | 80 | 3 | 12 | 3 | 2 | 1 | 0 | 0 | 0 |
30 | No | Travel_Rarely | 1358 | Research & Development | 24 | 1 | Life Sciences | 1 | 11 | 4 | Male | 67 | 3 | 1 | Laboratory Technician | 3 | Divorced | 2693 | 13335 | 1 | Y | No | 22 | 4 | 2 | 80 | 1 | 1 | 2 | 3 | 1 | 0 | 0 | 0 |
38 | No | Travel_Frequently | 216 | Research & Development | 23 | 3 | Life Sciences | 1 | 12 | 4 | Male | 44 | 2 | 3 | Manufacturing Director | 3 | Single | 9526 | 8787 | 0 | Y | No | 21 | 4 | 2 | 80 | 0 | 10 | 2 | 3 | 9 | 7 | 1 | 8 |
36 | No | Travel_Rarely | 1299 | Research & Development | 27 | 3 | Medical | 1 | 13 | 3 | Male | 94 | 3 | 2 | Healthcare Representative | 3 | Married | 5237 | 16577 | 6 | Y | No | 13 | 3 | 2 | 80 | 2 | 17 | 3 | 2 | 7 | 7 | 7 | 7 |
# Change to factors
employee_attrition <- employee_attrition_raw %>%
dplyr::mutate_if(is.character, as.factor) %>%
dplyr::select(Attrition, dplyr::everything())
In total we have 1,470 observations and 35 variables: the Attrition outcome plus 34 features.
tibble::glimpse(employee_attrition, width = 100)
## Observations: 1,470
## Variables: 35
## $ Attrition <fct> Yes, No, Yes, No, No, No, No, No, No, No, No, No, No, No, Yes, N…
## $ Age <dbl> 41, 49, 37, 33, 27, 32, 59, 30, 38, 36, 35, 29, 31, 34, 28, 29, …
## $ BusinessTravel <fct> Travel_Rarely, Travel_Frequently, Travel_Rarely, Travel_Frequent…
## $ DailyRate <dbl> 1102, 279, 1373, 1392, 591, 1005, 1324, 1358, 216, 1299, 809, 15…
## $ Department <fct> Sales, Research & Development, Research & Development, Research …
## $ DistanceFromHome <dbl> 1, 8, 2, 3, 2, 2, 3, 24, 23, 27, 16, 15, 26, 19, 24, 21, 5, 16, …
## $ Education <dbl> 2, 1, 2, 4, 1, 2, 3, 1, 3, 3, 3, 2, 1, 2, 3, 4, 2, 2, 4, 3, 2, 4…
## $ EducationField <fct> Life Sciences, Life Sciences, Other, Life Sciences, Medical, Lif…
## $ EmployeeCount <dbl> 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1…
## $ EmployeeNumber <dbl> 1, 2, 4, 5, 7, 8, 10, 11, 12, 13, 14, 15, 16, 18, 19, 20, 21, 22…
## $ EnvironmentSatisfaction <dbl> 2, 3, 4, 4, 1, 4, 3, 4, 4, 3, 1, 4, 1, 2, 3, 2, 1, 4, 1, 4, 1, 3…
## $ Gender <fct> Female, Male, Male, Female, Male, Male, Female, Male, Male, Male…
## $ HourlyRate <dbl> 94, 61, 92, 56, 40, 79, 81, 67, 44, 94, 84, 49, 31, 93, 50, 51, …
## $ JobInvolvement <dbl> 3, 2, 2, 3, 3, 3, 4, 3, 2, 3, 4, 2, 3, 3, 2, 4, 4, 4, 2, 3, 4, 2…
## $ JobLevel <dbl> 2, 2, 1, 1, 1, 1, 1, 1, 3, 2, 1, 2, 1, 1, 1, 3, 1, 1, 4, 1, 2, 1…
## $ JobRole <fct> Sales Executive, Research Scientist, Laboratory Technician, Rese…
## $ JobSatisfaction <dbl> 4, 2, 3, 3, 2, 4, 1, 3, 3, 3, 2, 3, 3, 4, 3, 1, 2, 4, 4, 4, 3, 1…
## $ MaritalStatus <fct> Single, Married, Single, Married, Married, Single, Married, Divo…
## $ MonthlyIncome <dbl> 5993, 5130, 2090, 2909, 3468, 3068, 2670, 2693, 9526, 5237, 2426…
## $ MonthlyRate <dbl> 19479, 24907, 2396, 23159, 16632, 11864, 9964, 13335, 8787, 1657…
## $ NumCompaniesWorked <dbl> 8, 1, 6, 1, 9, 0, 4, 1, 0, 6, 0, 0, 1, 0, 5, 1, 0, 1, 2, 5, 0, 7…
## $ Over18 <fct> Y, Y, Y, Y, Y, Y, Y, Y, Y, Y, Y, Y, Y, Y, Y, Y, Y, Y, Y, Y, Y, Y…
## $ OverTime <fct> Yes, No, Yes, Yes, No, No, Yes, No, No, No, No, Yes, No, No, Yes…
## $ PercentSalaryHike <dbl> 11, 23, 15, 11, 12, 13, 20, 22, 21, 13, 13, 12, 17, 11, 14, 11, …
## $ PerformanceRating <dbl> 3, 4, 3, 3, 3, 3, 4, 4, 4, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 4…
## $ RelationshipSatisfaction <dbl> 1, 4, 2, 3, 4, 3, 1, 2, 2, 2, 3, 4, 4, 3, 2, 3, 4, 2, 3, 3, 4, 2…
## $ StandardHours <dbl> 80, 80, 80, 80, 80, 80, 80, 80, 80, 80, 80, 80, 80, 80, 80, 80, …
## $ StockOptionLevel <dbl> 0, 1, 0, 0, 1, 0, 3, 1, 0, 2, 1, 0, 1, 1, 0, 1, 2, 2, 0, 0, 1, 0…
## $ TotalWorkingYears <dbl> 8, 10, 7, 8, 6, 8, 12, 1, 10, 17, 6, 10, 5, 3, 6, 10, 7, 1, 31, …
## $ TrainingTimesLastYear <dbl> 0, 3, 3, 3, 3, 2, 3, 2, 2, 3, 5, 3, 1, 2, 4, 1, 5, 2, 3, 3, 5, 4…
## $ WorkLifeBalance <dbl> 1, 3, 3, 3, 3, 2, 2, 3, 3, 2, 3, 3, 2, 3, 3, 3, 2, 2, 3, 3, 2, 3…
## $ YearsAtCompany <dbl> 6, 10, 0, 8, 2, 7, 1, 1, 9, 7, 5, 9, 5, 2, 4, 10, 6, 1, 25, 3, 4…
## $ YearsInCurrentRole <dbl> 4, 7, 0, 7, 2, 7, 0, 0, 7, 7, 4, 5, 2, 2, 2, 9, 2, 0, 8, 2, 2, 3…
## $ YearsSinceLastPromotion <dbl> 0, 1, 0, 3, 2, 3, 0, 0, 1, 7, 0, 0, 4, 1, 0, 8, 0, 0, 3, 1, 1, 0…
## $ YearsWithCurrManager <dbl> 5, 7, 0, 0, 2, 6, 0, 0, 8, 7, 3, 8, 3, 2, 3, 8, 5, 0, 7, 2, 3, 3…
Let’s take a quick look at the number of positive and negative cases to see whether we need to investigate sampling methods such as down/up sampling or synthetic sampling.
case_table <- employee_attrition %>%
dplyr::count(Attrition, name = "Count") %>%
dplyr::mutate(prop = Count / sum(Count))
case_table
## # A tibble: 2 x 3
## Attrition Count prop
## <fct> <int> <dbl>
## 1 No 1233 0.839
## 2 Yes 237 0.161
We do see a moderate class imbalance: about 16% of employees left. To stay consistent with the analysis performed in the article, we’ll leave the data as is and note this for future analysis.
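If we did want to rebalance the classes, down-sampling is probably the simplest option. Here’s a minimal sketch in base R on a toy data frame. The data and the names `toy`, `n_minority`, and `toy_down` are made up for illustration and are not part of the analysis.

```r
# Toy data with roughly the same imbalance as the attrition data (assumed values).
set.seed(1234)
toy <- data.frame(
  Attrition = factor(c(rep("No", 84), rep("Yes", 16)), levels = c("No", "Yes")),
  x = rnorm(100)
)

# Down-sample: keep every minority row and an equally sized random
# sample of the majority rows.
n_minority <- min(table(toy$Attrition))
keep <- unlist(lapply(
  split(seq_len(nrow(toy)), toy$Attrition),
  function(idx) idx[sample.int(length(idx), n_minority)]
))
toy_down <- toy[keep, ]

table(toy_down$Attrition)
```

The trade-off is that down-sampling throws away majority rows, which hurts more on a small data set like this one than on a large one.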
Although we’re letting H2O decide the optimal model for us, let’s still check that we have enough observations per variable under a logistic regression model.
Per Peduzzi et al., N = 10 * k / p, where N is the minimum number of observations needed, k is the number of features, and p is the smaller of the two class proportions (here, the proportion of Yes cases).
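# Number of features, excluding the Attrition outcome
k <- ncol(employee_attrition) - 1
# Proportion of positive (Yes) cases
p <- case_table %>%
dplyr::filter(Attrition == "Yes") %>%
dplyr::pull(prop)
N <- 10 * k / p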
N
## [1] 2108.861
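By this rule of thumb we would want roughly 2,109 observations, which is more than the 1,470 we have. A plain logistic regression on all 34 features would therefore be on the thin side for events per variable. We’ll note this and continue, to stay consistent with the article.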
Here we’ll set up our model and establish training, validation, and test sets.
h2o::h2o.init()
##
## H2O is not running yet, starting it now...
##
## Note: In case of errors look at the following log files:
## /tmp/RtmpUmXXBs/h2o_ben_started_from_r.out
## /tmp/RtmpUmXXBs/h2o_ben_started_from_r.err
##
##
## Starting H2O JVM and connecting: . Connection successful!
##
## R is connected to the H2O cluster:
## H2O cluster uptime: 1 seconds 213 milliseconds
## H2O cluster timezone: Etc/UTC
## H2O data parsing timezone: UTC
## H2O cluster version: 3.28.0.1
## H2O cluster version age: 1 month and 20 days
## H2O cluster name: H2O_started_from_R_ben_ugd504
## H2O cluster total nodes: 1
## H2O cluster total memory: 6.96 GB
## H2O cluster total cores: 4
## H2O cluster allowed cores: 4
## H2O cluster healthy: TRUE
## H2O Connection ip: localhost
## H2O Connection port: 54321
## H2O Connection proxy: NA
## H2O Internal Security: FALSE
## H2O API Extensions: Amazon S3, XGBoost, Algos, AutoML, Core V3, TargetEncoder, Core V4
## R Version: R version 3.6.2 (2019-12-12)
h2o::h2o.no_progress()
employee_attrition_h2o <- h2o::as.h2o(employee_attrition)
split_h2o <- h2o::h2o.splitFrame(employee_attrition_h2o, c(0.7, 0.15), seed = 1234)
train_h2o <- h2o::h2o.assign(split_h2o[[1]], "train")
valid_h2o <- h2o::h2o.assign(split_h2o[[2]], "valid")
test_h2o <- h2o::h2o.assign(split_h2o[[3]], "test")
y <- "Attrition"
x <- setdiff(names(train_h2o), y)
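For intuition, a three-way split like this can be sketched in base R. This only approximates the idea on a toy data frame and is not what H2O does internally. The names `split_id` and `splits` are mine. Passing two ratios, 0.7 and 0.15, leaves the remaining 0.15 for the third frame. As far as I understand, H2O also assigns rows probabilistically, so the split sizes are approximate rather than exact.

```r
set.seed(1234)
n <- 1470
toy <- data.frame(row = seq_len(n))

# Assign each row to a split with probabilities 0.70 / 0.15 / 0.15.
split_id <- sample(
  c("train", "valid", "test"),
  size = n,
  replace = TRUE,
  prob = c(0.70, 0.15, 0.15)
)

# One data frame per split; every row lands in exactly one of them.
splits <- split(toy, factor(split_id, levels = c("train", "valid", "test")))

sapply(splits, nrow)
```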
Time for the modeling step!
automl_models_h2o <- h2o::h2o.automl(
x = x,
y = y,
training_frame = train_h2o,
leaderboard_frame = valid_h2o,
max_runtime_secs = 30,
seed = 1234
)
automl_leader <- automl_models_h2o@leader
Let’s evaluate the model’s performance.
One interesting thing to note is that at this step we’re not even aware of which model H2O has selected. That’s an interesting feature of AutoML: we’re focused on the results instead of selecting features and optimizing parameters at this stage.
pred_h2o <- h2o::h2o.predict(object = automl_leader, newdata = test_h2o)
test_performance <- test_h2o %>%
tibble::as_tibble() %>%
dplyr::select(Attrition) %>%
dplyr::mutate(pred = as.vector(pred_h2o$predict)) %>%
dplyr::mutate_if(is.character, as.factor)
test_performance
## # A tibble: 211 x 2
## Attrition pred
## <fct> <fct>
## 1 No No
## 2 No No
## 3 Yes Yes
## 4 No No
## 5 No No
## 6 No No
## 7 Yes Yes
## 8 No No
## 9 No No
## 10 Yes Yes
## # … with 201 more rows
confusion_matrix <- test_performance %>%
table()
confusion_matrix
## pred
## Attrition No Yes
## No 170 12
## Yes 13 16
Here I liked the author’s pause to point out how well a naive model does. On this test set you could predict No for everyone and be right for 182 of 211 employees, about 86% of the time, so the model’s accuracy of about 88% is only a couple of points better than a naive always-No model, which isn’t great. However, the author goes on to point out recall’s value to HR: the organization would prefer to misclassify employees as high risk when they’re not rather than misclassify them as not high risk when they are. Too often the focus in modeling is on accuracy and misses what is meaningful to the organization. While our numbers differ from those in the article, our recall means the model flags about 55% of the employees who actually leave (16 of 29), giving the organization a chance to intervene with them.
# table() is stored column-major: [1] = (No, No), [2] = (Yes, No),
# [3] = (No, Yes), [4] = (Yes, Yes), with rows = actual, columns = predicted
tn <- confusion_matrix[1]
tp <- confusion_matrix[4]
fp <- confusion_matrix[3]
fn <- confusion_matrix[2]
accuracy <- (tp + tn) / (tp + tn + fp + fn)
misclassification_rate <- 1 - accuracy
recall <- tp / (tp + fn)
precision <- tp / (tp + fp)
# Accuracy of always predicting No: the share of actual No cases
null_accuracy <- (tn + fp) / (tp + tn + fp + fn)
tibble::tibble(
accuracy,
misclassification_rate,
recall,
precision,
null_accuracy
) %>%
t()
## [,1]
## accuracy 0.8815166
## misclassification_rate 0.1184834
## recall 0.5517241
## precision 0.5714286
## null_accuracy 0.8625592
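If recall matters more than accuracy, one lever is the probability threshold used to turn a predicted probability into a Yes or a No. The following sketch uses made-up probabilities and labels (`truth`, `p_yes`, and `recall_precision` are hypothetical names, not part of the analysis) to show how lowering the threshold trades precision for recall.

```r
# Hypothetical true labels and predicted probabilities of "Yes" (made up).
truth <- c("Yes", "Yes", "Yes", "Yes", "No", "No", "No", "No", "No", "No")
p_yes <- c(0.90, 0.60, 0.35, 0.20, 0.45, 0.30, 0.15, 0.10, 0.05, 0.02)

# Recall and precision for a given classification threshold.
recall_precision <- function(truth, p_yes, threshold) {
  pred <- ifelse(p_yes >= threshold, "Yes", "No")
  tp <- sum(pred == "Yes" & truth == "Yes")
  fp <- sum(pred == "Yes" & truth == "No")
  fn <- sum(pred == "No" & truth == "Yes")
  c(threshold = threshold,
    recall = tp / (tp + fn),
    precision = tp / (tp + fp))
}

threshold_table <- rbind(
  recall_precision(truth, p_yes, 0.50),
  recall_precision(truth, p_yes, 0.25)
)
threshold_table
```

On this toy data, dropping the threshold from 0.50 to 0.25 raises recall from 0.50 to 0.75 while precision falls from 1.00 to 0.60.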
automl_leader
## Model Details:
## ==============
##
## H2OBinomialModel: stackedensemble
## Model ID: StackedEnsemble_BestOfFamily_AutoML_20200206_044731
## NULL
##
##
## H2OBinomialMetrics: stackedensemble
## ** Reported on training data. **
##
## MSE: 0.02772933
## RMSE: 0.1665212
## LogLoss: 0.1215058
## Mean Per-Class Error: 0.04749831
## AUC: 0.9964706
## AUCPR: 0.958869
## Gini: 0.9929412
##
## Confusion Matrix (vertical: actual; across: predicted) for F1-optimal threshold:
## No Yes Error Rate
## No 859 11 0.012644 =11/870
## Yes 14 156 0.082353 =14/170
## Totals 873 167 0.024038 =25/1040
##
## Maximum Metrics: Maximum metrics at their respective thresholds
## metric threshold value idx
## 1 max f1 0.302228 0.925816 137
## 2 max f2 0.199425 0.957207 172
## 3 max f0point5 0.417172 0.944730 123
## 4 max accuracy 0.302228 0.975962 137
## 5 max precision 0.992197 1.000000 0
## 6 max recall 0.199425 1.000000 172
## 7 max specificity 0.992197 1.000000 0
## 8 max absolute_mcc 0.302228 0.911526 137
## 9 max min_per_class_accuracy 0.234473 0.970588 156
## 10 max mean_per_class_accuracy 0.199425 0.978161 172
## 11 max tns 0.992197 870.000000 0
## 12 max fns 0.992197 166.000000 0
## 13 max fps 0.026594 870.000000 399
## 14 max tps 0.199425 170.000000 172
## 15 max tnr 0.992197 1.000000 0
## 16 max fnr 0.992197 0.976471 0
## 17 max fpr 0.026594 1.000000 399
## 18 max tpr 0.199425 1.000000 172
##
## Gains/Lift Table: Extract with `h2o.gainsLift(<model>, <data>)` or `h2o.gainsLift(<model>, valid=<T/F>, xval=<T/F>)`
##
## H2OBinomialMetrics: stackedensemble
## ** Reported on cross-validation data. **
## ** 5-fold cross-validation on training data (Metrics computed for combined holdout predictions) **
##
## MSE: 0.08896314
## RMSE: 0.2982669
## LogLoss: 0.3077805
## Mean Per-Class Error: 0.2524003
## AUC: 0.8320047
## AUCPR: 0.6496573
## Gini: 0.6640095
##
## Confusion Matrix (vertical: actual; across: predicted) for F1-optimal threshold:
## No Yes Error Rate
## No 830 40 0.045977 =40/870
## Yes 78 92 0.458824 =78/170
## Totals 908 132 0.113462 =118/1040
##
## Maximum Metrics: Maximum metrics at their respective thresholds
## metric threshold value idx
## 1 max f1 0.394694 0.609272 119
## 2 max f2 0.078752 0.639894 294
## 3 max f0point5 0.608480 0.685654 68
## 4 max accuracy 0.608480 0.888462 68
## 5 max precision 0.989688 1.000000 0
## 6 max recall 0.029366 1.000000 393
## 7 max specificity 0.989688 1.000000 0
## 8 max absolute_mcc 0.394694 0.550091 119
## 9 max min_per_class_accuracy 0.115575 0.756322 250
## 10 max mean_per_class_accuracy 0.209415 0.776099 182
## 11 max tns 0.989688 870.000000 0
## 12 max fns 0.989688 169.000000 0
## 13 max fps 0.025840 870.000000 399
## 14 max tps 0.029366 170.000000 393
## 15 max tnr 0.989688 1.000000 0
## 16 max fnr 0.989688 0.994118 0
## 17 max fpr 0.025840 1.000000 399
## 18 max tpr 0.029366 1.000000 393
##
## Gains/Lift Table: Extract with `h2o.gainsLift(<model>, <data>)` or `h2o.gainsLift(<model>, valid=<T/F>, xval=<T/F>)`
Of note for the next section, and unsurprisingly, we’ll use the family, link, and regularization info.
Since we noticed above that the model H2O selected is a GLM with ridge regularization, we’ll attempt to replicate the results using the glmnet R package. And since we’re already doing a comparison, why not try a few different types of regression: lasso, ridge, and elastic net.
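For reference, glmnet ties these three together through a single mixing parameter alpha. For the binomial family, the objective it minimizes can be written as follows, where N is the number of training rows (not the Peduzzi N from earlier):

```latex
\min_{\beta_0,\,\beta}\;
  -\frac{1}{N}\sum_{i=1}^{N}\Bigl[\,y_i\,(\beta_0 + x_i^{\top}\beta)
  - \log\bigl(1 + e^{\beta_0 + x_i^{\top}\beta}\bigr)\Bigr]
  + \lambda\Bigl[\tfrac{1-\alpha}{2}\,\lVert\beta\rVert_2^2
  + \alpha\,\lVert\beta\rVert_1\Bigr]
```

Setting alpha = 0 gives ridge, alpha = 1 gives lasso, and anything in between is an elastic net, with lambda controlling the overall strength of the penalty.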
Here we’ll split our sample according to the same sets we used in our H2O model, combining the training and validation frames for training and keeping the test frame for testing.
I’ll drop the feature Over18 (column 21 of x) since it’s a constant factor and doesn’t add any information about the outcome.
x_train <- rbind(as.data.frame(train_h2o[x][, -c(21)]), as.data.frame(valid_h2o[x][, -c(21)]))
y_train <- rbind(as.data.frame(train_h2o[y]), as.data.frame(valid_h2o[y]))$Attrition
x_test <- as.data.frame(test_h2o[x][, -c(21)])
y_test <- as.data.frame(test_h2o[y])$Attrition
x_train <- model.matrix(~.-1, x_train)
x_test <- model.matrix(~.-1, x_test)
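Dropping a column by position works here, but it breaks if the column order changes. A more defensive option is to detect constant columns by their values. The sketch below uses a small toy data frame, and `is_constant`, `constant_cols`, and `toy_reduced` are names I’ve made up for illustration.

```r
# Toy data: two constant columns and one that varies.
toy <- data.frame(
  Over18 = factor(rep("Y", 5)),
  StandardHours = rep(80, 5),
  Age = c(41, 49, 37, 33, 27)
)

# A column is constant if it has exactly one distinct value.
is_constant <- vapply(toy, function(col) length(unique(col)) == 1, logical(1))
constant_cols <- names(toy)[is_constant]
constant_cols

# Keep only the columns that carry information.
toy_reduced <- toy[, !is_constant, drop = FALSE]
names(toy_reduced)
```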
The value of alpha isn’t apparent, although we do have a hint (ridge regression), so we’ll perform a grid search over alpha.
alpha_grid <- seq(0, 1, .1)
model_fits <- lapply(alpha_grid, function(alpha){
set.seed(1234)
glmnet::cv.glmnet(x_train, as.factor(y_train),
family = "binomial",
alpha = alpha,
type.measure = "mse",
nfolds = 20
)
})
mse_lambda_min <- lapply(model_fits, function(model){
lambda_min_mse <- model$cvm[which(model$lambda == model$lambda.min) ]
lambda_1se_mse <- model$cvm[which(model$lambda == model$lambda.1se) ]
lambda_min <- model$lambda.min
lambda_1se <- model$lambda.1se
data.frame(lambda_min_mse = lambda_min_mse, lambda_min = lambda_min, lambda_1se_mse = lambda_1se_mse, lambda_1se = lambda_1se)
})
mse_lambda_min <- do.call(rbind, mse_lambda_min)
mse_lambda_min$alpha <- alpha_grid
mse_lambda_min
## lambda_min_mse lambda_min lambda_1se_mse lambda_1se alpha
## 1 0.1910575 0.008881137 0.1994429 0.052017005 0.0
## 2 0.1910322 0.007724360 0.1992090 0.037560485 0.1
## 3 0.1910282 0.006749269 0.2000218 0.029903457 0.2
## 4 0.1910617 0.005948091 0.2002854 0.024012547 0.3
## 5 0.1910651 0.005897270 0.2002692 0.019765305 0.4
## 6 0.1910490 0.005177797 0.1991611 0.015812244 0.5
## 7 0.1910771 0.004314831 0.2000434 0.014461599 0.6
## 8 0.1911163 0.004059019 0.1994610 0.012395656 0.7
## 9 0.1911311 0.003551641 0.1990230 0.010846199 0.8
## 10 0.1911558 0.003157014 0.2004191 0.010581058 0.9
## 11 0.1911867 0.002841313 0.2001205 0.009522952 1.0
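Interestingly enough, we do not see our value of lambda. There are a couple of reasons for this:
Regardless, we’ll pick the alpha and lambda with the lowest MSE for both lambda_min and lambda_1se and see how they compare. Note that the values hard-coded below differ somewhat from the table above, presumably because they were taken from a different run of the grid search.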
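Rather than reading the lowest-MSE rows off the table by eye, they can be picked programmatically. The sketch below re-types the MSE columns from the grid search table into a stand-alone data frame (called `grid` here, a name of my own) so that it runs without the fitted models.

```r
# MSE columns re-typed from the grid search table.
grid <- data.frame(
  alpha = seq(0, 1, 0.1),
  lambda_min_mse = c(0.1910575, 0.1910322, 0.1910282, 0.1910617, 0.1910651,
                     0.1910490, 0.1910771, 0.1911163, 0.1911311, 0.1911558,
                     0.1911867),
  lambda_1se_mse = c(0.1994429, 0.1992090, 0.2000218, 0.2002854, 0.2002692,
                     0.1991611, 0.2000434, 0.1994610, 0.1990230, 0.2004191,
                     0.2001205)
)

# Row with the lowest cross-validated MSE under each lambda rule.
best_min_row <- which.min(grid$lambda_min_mse)
best_1se_row <- which.min(grid$lambda_1se_mse)

grid[c(best_min_row, best_1se_row), ]
```

On the values in the table this picks the alpha = 0.2 row for lambda_min and the alpha = 0.8 row for lambda_1se, though the differences in MSE across alpha are tiny.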
glmnet_min <- glmnet::glmnet(x_train, as.factor(y_train),
family = "binomial",
alpha = 0.2,
lambda = 0.005412513
)
glmnet_1se <- glmnet::glmnet(x_train, as.factor(y_train),
family = "binomial",
alpha = 0.4,
lambda = 0.02299651
)
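# Each model was fit at a single lambda, so no `s` argument is needed;
# "lambda.min" and "lambda.1se" only apply to cv.glmnet objects.
glmnet_min_pred <- predict(glmnet_min,
newx = x_test,
type = "class"
)
glmnet_1se_pred <- predict(glmnet_1se,
newx = x_test,
type = "class"
)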
To measure performance for our two models, we’ll make it easy and create a performance function.
performance <- function(confusion_matrix) {
# Rows are actual, columns are predicted; table() is stored column-major
tn <- confusion_matrix[1]
tp <- confusion_matrix[4]
fp <- confusion_matrix[3]
fn <- confusion_matrix[2]
accuracy <- (tp + tn) / (tp + tn + fp + fn)
misclassification_rate <- 1 - accuracy
recall <- tp / (tp + fn)
precision <- tp / (tp + fp)
# Accuracy of always predicting No: the share of actual No cases
null_accuracy <- (tn + fp) / (tp + tn + fp + fn)
tibble::tibble(
accuracy,
misclassification_rate,
recall,
precision,
null_accuracy
) %>%
t()
}
A quick look at the confusion matrix for the lambda_min model shows mixed results on recall: it catches 15 of the 29 employees who left.
glmnet_min_confusion <- table(y_test, glmnet_min_pred, dnn = c("Attrition", "Prediction"))
glmnet_min_confusion
## Prediction
## Attrition No Yes
## No 173 9
## Yes 14 15
Compared to our H2O model, the accuracy is slightly higher (0.891 vs. 0.882). However, our recall is lower (0.517 vs. 0.552). Against a null accuracy of 0.863 on this test set, the margin over a naive always-No model is still under three points.
performance(glmnet_min_confusion)
## [,1]
## accuracy 0.8909953
## misclassification_rate 0.1090047
## recall 0.5172414
## precision 0.6250000
## null_accuracy 0.8625592
Our 1se model has likewise traded recall for precision. As you can see from the confusion matrix, when the model predicts Yes it is almost always correct (8 of 9), but it finds only 8 of the 29 employees who left. Its accuracy is the highest of the three models, yet it is still only about three points above the null accuracy.
glmnet_1se_confusion <- table(y_test, glmnet_1se_pred, dnn = c("Attrition", "Prediction"))
glmnet_1se_confusion
## Prediction
## Attrition No Yes
## No 181 1
## Yes 21 8
performance(glmnet_1se_confusion)
## [,1]
## accuracy 0.8957346
## misclassification_rate 0.1042654
## recall 0.2758621
## precision 0.8888889
## null_accuracy 0.8625592
Overall, the GLMNET model has similar results to the H2O model, but it does tend to focus more on accuracy and precision at the expense of recall. It’s important to keep these things in mind when developing a model and ensuring you’re optimizing for the task at hand.
We’ll return now to the original article.
Here we’ll use the lime package to produce feature importance plots.
An appealing aspect of lime is that it is model agnostic: we can apply the same technique regardless of whether the model is a random forest or a neural network.
Since lime is still only a few years old, I’ll give a brief explanation of what it does. The first step is taking one of our predicted observations, creating slight perturbations of it, and getting predictions for those perturbations from the original model. It then fits a simple linear model to those predictions, weighted by how close each perturbation is to the original observation, so the linear model reflects the local behavior of the original model.
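To make that concrete, here is a minimal base R sketch of the idea. It is not how the lime package is implemented, only an illustration: the `black_box` function, the two features, the noise level, and the kernel width are all assumptions of mine.

```r
set.seed(1234)

# A stand-in "black box": any function that maps features to a probability.
black_box <- function(d) plogis(2 * d$x1 - 0.5 * d$x2^2)

# The single observation we want to explain.
obs <- data.frame(x1 = 1, x2 = 2)

# 1. Perturb the observation with small random noise.
n <- 500
perturbed <- data.frame(
  x1 = obs$x1 + rnorm(n, sd = 0.5),
  x2 = obs$x2 + rnorm(n, sd = 0.5)
)

# 2. Ask the black box for predictions on the perturbed points.
perturbed$pred <- black_box(perturbed)

# 3. Weight each perturbed point by its proximity to the original observation.
kernel_width <- 0.5
dist <- sqrt((perturbed$x1 - obs$x1)^2 + (perturbed$x2 - obs$x2)^2)
perturbed$w <- exp(-dist^2 / kernel_width^2)

# 4. Fit a weighted linear model; its coefficients are the local explanation.
local_fit <- lm(pred ~ x1 + x2, data = perturbed, weights = w)
coef(local_fit)
```

Around this observation the black box increases with x1 and decreases with x2, and the signs of the local coefficients reflect that even though the underlying function is not linear.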
# Attach lime so the model_type() and predict_model() generics are available
library(lime)
model_type.H2OBinomialModel <- function(x, ...) return('classification')
predict_model.H2OBinomialModel <- function(x, newdata, type, ...) {
pred <- h2o::h2o.predict(x, h2o::as.h2o(newdata))
return(as.data.frame(pred[, -1]))
}
predict_model(x = automl_leader, newdata = as.data.frame(test_h2o[, -1]), type = 'raw') %>%
tibble::as_tibble()
## # A tibble: 211 x 2
## No Yes
## <dbl> <dbl>
## 1 0.858 0.142
## 2 0.957 0.0426
## 3 0.0563 0.944
## 4 0.966 0.0341
## 5 0.890 0.110
## 6 0.956 0.0442
## 7 0.138 0.862
## 8 0.939 0.0612
## 9 0.909 0.0912
## 10 0.456 0.544
## # … with 201 more rows
explainer <- lime::lime(
as.data.frame(train_h2o[, -1]),
model = automl_leader,
bin_continuous = FALSE
)
explanation <- lime::explain(
as.data.frame(test_h2o[1:10, -1]),
explainer = explainer,
n_labels = 1,
n_features = 4,
kernel_width = 0.5
)
Let’s take a look at the feature importance plots from lime.
Note that these differ quite a bit from those found in the article. This highlights an important point about reproducibility. Since the author didn’t disclose or set a seed value for AutoML, we cannot reproduce their results; it’s been left up to chance. As you can see below, the feature importance gives a very different perspective on the possible interventions for attrition.
lime::plot_features(explanation) +
ggplot2::labs(title = "Employee Attrition Prediction: Feature Importance",
subtitle = "First 10 Cases From Test Set")
From the feature plots we generated, a few features stand out as ones we can explore and possibly build intervention plans around.
attrition_critical_features <- employee_attrition %>%
tibble::as_tibble() %>%
dplyr::select(Attrition, NumCompaniesWorked, YearsSinceLastPromotion, OverTime) %>%
dplyr::mutate(Case = dplyr::row_number()) %>%
dplyr::select(Case, dplyr::everything())
attrition_critical_features
## # A tibble: 1,470 x 5
## Case Attrition NumCompaniesWorked YearsSinceLastPromotion OverTime
## <int> <fct> <dbl> <dbl> <fct>
## 1 1 Yes 8 0 Yes
## 2 2 No 1 1 No
## 3 3 Yes 6 0 Yes
## 4 4 No 1 3 Yes
## 5 5 No 9 2 No
## 6 6 No 0 3 No
## 7 7 No 4 0 Yes
## 8 8 No 1 0 No
## 9 9 No 0 1 No
## 10 10 No 6 7 No
## # … with 1,460 more rows
We can see from the rankings that, in general, the more companies an individual has worked at, the more likely they are to leave. Oddly, the relationship isn’t monotonic; I would have assumed it would be and that it would capture job hoppers. As an intervention, I’m not sure what you could possibly do; you can’t do much about a person’s previous job history. As a prehire check, this could be useful information if you’re looking for a long-term employee.
attrition_critical_features %>%
dplyr::group_by(NumCompaniesWorked, Attrition) %>%
dplyr::summarise(n = dplyr::n()) %>%
dplyr::mutate(freq = n / sum(n)) %>%
dplyr::filter(Attrition == 'Yes') %>%
ggplot2::ggplot(ggplot2::aes(forcats::fct_reorder(as.factor(NumCompaniesWorked), freq), freq)) +
ggplot2::geom_bar(stat = 'identity') +
ggplot2::coord_flip() +
ggplot2::ylim(0, 1) +
ggplot2::ylab("Attrition Percentage (Yes / Total)") +
ggplot2::xlab("NumCompaniesWorked")
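One caveat when reading a chart like this: the bars are ordered by attrition rate, not by the number of companies, so a monotonic trend is hard to judge by eye. A quick check on the rates themselves can help. The counts below are invented for illustration, and `rates` and `is_monotonic` are hypothetical names.

```r
# Hypothetical counts (made up), one row per group value.
rates <- data.frame(
  companies = 0:5,
  left = c(10, 6, 9, 4, 3, 8),
  total = c(100, 80, 60, 50, 40, 30)
)
rates$freq <- rates$left / rates$total

# Put the groups in their natural order before comparing neighbors.
rates <- rates[order(rates$companies), ]

# Monotonic increasing means each rate is at least as high as the one before.
is_monotonic <- all(diff(rates$freq) >= 0)
is_monotonic

rates[, c("companies", "freq")]
```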
No surprises here. A large number of employees who stay are not working overtime. Interventions here could include hiring more staff or offering incentives for those times when overtime is necessary.
attrition_critical_features %>%
dplyr::mutate(OverTime = ifelse(OverTime == 'Yes', 1, 0)) %>%
ggplot2::ggplot(ggplot2::aes(Attrition, OverTime)) +
ggplot2::geom_violin(trim=TRUE) +
ggplot2::geom_jitter(shape = 16, position = ggplot2::position_jitter(0.4))
This is also, interestingly, not monotonic. It’s clear at the top of the chart that employees leave if they haven’t been promoted for a long time.
Possible interventions here would include promotions or, in lieu of that, recognition programs.
attrition_critical_features %>%
dplyr::group_by(YearsSinceLastPromotion, Attrition) %>%
dplyr::summarise(n = dplyr::n()) %>%
dplyr::mutate(freq = n / sum(n)) %>%
dplyr::filter(Attrition == 'Yes') %>%
ggplot2::ggplot(ggplot2::aes(forcats::fct_reorder(as.factor(YearsSinceLastPromotion), freq), freq)) +
ggplot2::geom_bar(stat = 'identity') +
ggplot2::coord_flip() +
ggplot2::ylim(0, 1) +
ggplot2::ylab("Attrition Percentage (Yes / Total)") +
ggplot2::xlab("YearsSinceLastPromotion")
Working through this example has been interesting. Not only were we able to explore interesting employee data, but we also saw how models compare and the pitfalls of not ensuring your work is reproducible.