nestedcv provides two methods for understanding fitted models. The
simplest of these is to plot variable importance. The newer method is to
calculate Shapley values for each predictor.
For regression models such as glmnet, variable importance is represented by the coefficients of the model, scaled and ordered by absolute value from largest to smallest. In addition, the outer folds of nested CV allow us to show the variance of model coefficients across each outer fold plus the final model, and hence see how stable the model is. We can also overlay how often predictors are selected in each model to give a sense of the stability of predictor selection.
In the example below using the Boston housing dataset, a glmnet
regression model is fitted and the variable importance for predictors is
shown based on the coefficients in the final model and tuned models from
10 outer folds. min_1se is set to 1, which is the equivalent of
specifying s = "lambda.1se" with glmnet, to encourage a more sparse model.
library(nestedcv)
library(mlbench) # Boston housing dataset
data(BostonHousing2)
dat <- BostonHousing2
y <- dat$cmedv
x <- subset(dat, select = -c(cmedv, medv, town, chas))
# Fit a glmnet model using nested CV
set.seed(1, "L'Ecuyer-CMRG")
fit <- nestcv.glmnet(y, x, family = "gaussian",
                     min_1se = 1, alphaSet = 1, cv.cores = 2)
vs <- var_stability(fit)
vs
##                 mean           sd         sem frequency sign direction final
## lon     48.393372145 13.359312172 4.027984176        11   -1  negative   yes
## rm      35.277524448 12.364884451 3.728152936        11    1  positive   yes
## ptratio  5.660722405  1.767597466 0.532950689        11   -1  negative   yes
## lstat    4.460941717  1.651898823 0.498066235        11   -1  negative   yes
## nox      3.632481722  5.345960957 1.611867876         4   -1  negative   yes
## dis      1.266762305  1.029387844 0.310372113         8   -1  negative   yes
## lat      1.135429569  3.084782762 0.930096998         2    1  positive    no
## crim     0.120941548  0.085181803 0.025683280         9   -1  negative   yes
## b        0.044978058  0.009116406 0.002748700        11    1  positive   yes
## tax      0.005062703  0.005459025 0.001645958         8   -1  negative   yes
## zn       0.001783381  0.005914805 0.001783381         1    1  positive    no
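The columns in this table can be reproduced in spirit with a few lines of base R. The following is a minimal sketch on a made-up coefficient matrix, not the internal implementation of var_stability(); the object names (coefs, stab) and the toy values are assumptions for illustration only. It treats a predictor that was not selected in a model as having a coefficient of zero and computes the statistics across all models (outer folds plus final).

```r
# Toy coefficients: rows are predictors, columns are 4 outer-fold models
# plus the final model (hypothetical values, not from the Boston example)
coefs <- matrix(
  c( 2.0,  1.8,  2.2,  1.9,  2.1,
    -0.5,  0.0, -0.4,  0.0, -0.6,
     0.0,  0.3,  0.0,  0.0,  0.0),
  nrow = 3, byrow = TRUE,
  dimnames = list(c("a", "b", "c"), c(paste0("fold", 1:4), "final"))
)

imp <- abs(coefs)          # importance as absolute coefficient
n_models <- ncol(imp)      # outer folds + final model

stab <- data.frame(
  mean      = rowMeans(imp),
  sd        = apply(imp, 1, sd),
  frequency = rowSums(coefs != 0),           # how often each predictor is selected
  sign      = sign(rowMeans(coefs)),         # overall direction of effect
  final     = ifelse(coefs[, "final"] != 0, "yes", "no")
)
stab$sem <- stab$sd / sqrt(n_models)         # standard error of the mean

# Order by mean importance, largest first
stab <- stab[order(-stab$mean), ]
stab
```

In this sketch a predictor selected in only one model has a mean equal to its standard error, which is the same pattern seen for zn in the output above, where sem equals sd divided by the square root of 11 models.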
Variable stability can be plotted using plot_var_stability().
p1 <- plot_var_stability(fit)
# overlay directionality using colour
p2 <- plot_var_stability(fit, final = FALSE, direction = 1)
# or show directionality with the sign of the variable importance
# plot_var_stability(fit, final = FALSE, direction = 2)
ggpubr::ggarrange(p1, p2, ncol=2)
By default only the predictors chosen in the final model are shown.
If the argument final is set to FALSE, all predictors are shown to help
understand how often they are selected, which is helpful when pushing
sparsity in models. Original coefficients can be shown instead of being
scaled as a percentage by setting percent = FALSE.
The frequency with which each variable is selected in the outer folds as well as the final model is shown as bubble size. This is helpful for sparse models or with filters, to determine how often variables end up in the model in each fold.
We can overlay directionality using either colour (direction = 1) or the
sign of the variable importance to splay the plot (direction = 2).
For glmnet, the direction of effect is taken directly from the sign
of model coefficients. For caret models, direction of effect is not
readily available, so as a substitute, the directionality of each
predictor is determined by the function var_direction() using the sign
of a t-test for binary classification or the sign of the regression
coefficient for continuous outcomes (not available for multiclass caret
models). To better understand the relationship and direction of effect
of each predictor within the final model, we recommend using SHAP values.
Alternatively a barplot is available using barplot_var_stability().
The caret package allows variable importance to be calculated for other
model types. This is model dependent. For example, with random forest,
variable importance is usually calculated as the mean decrease in Gini
impurity each time a variable is chosen in a tree. With caret models,
you may have to load the appropriate package for that model to calculate
the variable importance, e.g. this is necessary for GBM models with the
gbm package.
To change the colour scheme for direction = 1, overwrite
scale_fill_manual().
# change bubble colour scheme
p1 + scale_fill_manual(values=c("orange", "green3"))
The original implementation of shap by Scott Lundberg is a
Python package. We suggest using the R package fastshap for
examining nestedcv models since it works with classification or
regression as well as any model type (regression such as glmnet, or
tree-based such as random forest, GBM or xgboost).
In the example below using the same glmnet regression model, the
variable importance for predictors is measured using Shapley values. The
function explain() from the fastshap package needs a wrapper function
for prediction using the model. nestedcv provides pred_nestcv_glmnet,
which is a wrapper function for binary classification or regression with
nestcv.glmnet fitted models.
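To make the idea of a prediction wrapper concrete, here is a minimal sketch using a plain lm model on the built-in mtcars data rather than a nestedcv model. fastshap's explain() expects pred_wrapper to be a function of two arguments, object and newdata, that returns a numeric vector of predictions. The names fit_lm and pred_lm are hypothetical and exist only for this illustration; this is not how pred_nestcv_glmnet itself is written.

```r
# Fit a simple linear model on a built-in dataset (illustration only)
fit_lm <- lm(mpg ~ wt + hp, data = mtcars)

# A prediction wrapper takes the fitted object and new data,
# and returns a plain numeric vector with one prediction per row
pred_lm <- function(object, newdata) {
  as.numeric(predict(object, newdata = newdata))
}

# The wrapper can be checked on its own before passing it to explain()
p <- pred_lm(fit_lm, mtcars[1:5, c("wt", "hp")])
p
```

Checking a wrapper directly like this, before any Shapley values are computed, is a quick way to confirm that it returns one number per row.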
library(fastshap)
# Generate SHAP values using fastshap::explain
# Only using 5 repeats here for speed, but recommend higher values of nsim
sh <- explain(fit, X=x, pred_wrapper = pred_nestcv_glmnet, nsim = 5)
# Plot overall variable importance
plot_shap_bar(sh, x)
## Variables with mean(|SHAP|)=0: tract, lat, zn, indus, age, rad
plot_shap_bar() ranks predictors in terms of variable importance
calculated as mean absolute SHAP value. The plot also overlays colour
for the direction of the main effect of each variable on the model,
based on correlating the value of each variable against the SHAP value
to see if the overall correlation for that variable is positive or
negative. For regression models such as glmnet this corresponds to the
sign of each coefficient. For more complex models with interactions the
direction of effect may be variable and non-linear.
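The two quantities described above, mean absolute SHAP value and a correlation-based direction, can be sketched in base R. This is an illustration on simulated data, not the code used by plot_shap_bar(). It relies on the fact that for a linear model with independent features the SHAP value of a feature is its coefficient multiplied by the feature's deviation from its mean. The objects x_toy, beta and shap_toy are made up for this sketch.

```r
set.seed(1)
# Simulated predictors and hypothetical linear-model coefficients
x_toy <- data.frame(u = rnorm(50), v = rnorm(50), w = rnorm(50))
beta  <- c(u = 2, v = -1, w = 0)

# SHAP values for a linear model: coefficient * (value - feature mean)
shap_toy <- sapply(names(beta), function(nm) {
  beta[[nm]] * (x_toy[[nm]] - mean(x_toy[[nm]]))
})

# Variable importance: mean absolute SHAP value per predictor
importance <- colMeans(abs(shap_toy))

# Direction: sign of the correlation between feature value and SHAP value
# (predictors whose SHAP values are all zero get direction 0)
direction <- sapply(names(beta), function(nm) {
  if (sd(shap_toy[, nm]) == 0) return(0)
  sign(cor(x_toy[[nm]], shap_toy[, nm]))
})

sort(importance, decreasing = TRUE)
direction
```

Here the direction recovered for each predictor matches the sign of its coefficient, and the predictor with a zero coefficient has zero importance, mirroring the "Variables with mean(|SHAP|)=0" message shown above.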
Note that these SHAP plots only show the final fitted model (on the whole data), whereas the variable stability plots examine models across the outer CV folds as well as the final model.
Be careful if you have massive numbers of predictors in x (see
Troubleshooting).
nestedcv also provides a quick plotting function plot_shap_beeswarm
for generating ggplot2 beeswarm plots similar to those made by the
original Python shap package.
# Plot beeswarm plot
plot_shap_beeswarm(sh, x, size = 1)
## Variables with mean(|SHAP|)=0: tract, lat, zn, indus, age, rad
size can be set to control the size of points. cex controls the amount
of overlap of the beeswarm. scheme controls the colour scheme as a
vector of 3 colours.
Alternatively try the shapviz package.
The process for a caret model fitted using nestedcv is similar. Use the
pred_train wrapper when calling explain().
# Only 3 outer folds to speed up process
fit <- nestcv.train(y, x,
                    method = "gbm",
                    n_outer_folds = 3, cv.cores = 2)
# Only using 5 repeats here for speed, but recommend higher values of nsim
sh <- explain(fit, X=x, pred_wrapper = pred_train, nsim = 5)
plot_shap_beeswarm(sh, x, size = 1)
For multinomial classification, a wrapper is needed for each class.
For nestcv.glmnet models, we provide wrappers for the first 3 classes:
pred_nestcv_glmnet_class1, pred_nestcv_glmnet_class2 etc. They are very
simple functions and easy to extend to other classes.
For nestcv.train() models, similarly we provide pred_train_class1,
pred_train_class2 etc. for the first 3 classes. Again these are easily
extended.
As a toy example, we show the iris dataset which has 3 classes.
library(ggplot2)
data("iris")
dat <- iris
y <- dat$Species
x <- dat[, 1:4]
# Only 3 outer folds to speed up process
fit <- nestcv.glmnet(y, x, family = "multinomial", n_outer_folds = 3, alphaSet = 0.6)
# SHAP values for each of the 3 classes
sh1 <- explain(fit, X=x, pred_wrapper = pred_nestcv_glmnet_class1, nsim = 5)
sh2 <- explain(fit, X=x, pred_wrapper = pred_nestcv_glmnet_class2, nsim = 5)
sh3 <- explain(fit, X=x, pred_wrapper = pred_nestcv_glmnet_class3, nsim = 5)
s1 <- plot_shap_bar(sh1, x, sort = FALSE) +
  ggtitle("Setosa")
s2 <- plot_shap_bar(sh2, x, sort = FALSE) +
  ggtitle("Versicolor")
s3 <- plot_shap_bar(sh3, x, sort = FALSE) +
  ggtitle("Virginica")
ggpubr::ggarrange(s1, s2, s3, ncol=3, legend = "bottom", common.legend = TRUE)
Or using beeswarm plots.
s1 <- plot_shap_beeswarm(sh1, x, sort = FALSE, cex = 0.7) +
  ggtitle("Setosa")
s2 <- plot_shap_beeswarm(sh2, x, sort = FALSE, cex = 0.7) +
  ggtitle("Versicolor")
s3 <- plot_shap_beeswarm(sh3, x, sort = FALSE, cex = 0.7) +
  ggtitle("Virginica")
ggpubr::ggarrange(s1, s2, s3, ncol=3, legend = "right", common.legend = TRUE)
Binary factors generally should not cause any handling problems, as
most functions in nestedcv convert them to 0 and 1, and directionality
can be interpreted. However, multi-level factors with 3 or more levels
are not clearly interpretable with SHAP values. We recommend these are
one-hot encoded with dummy variables using the function one_hot(). This
function accepts dataframes and can encode multiple factors and
character columns, producing a matrix as a result.
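For readers unfamiliar with one-hot encoding, the idea can be sketched with base R's model.matrix(). This is only an illustration of the encoding itself and is not the implementation of one_hot(), whose column naming and treatment of reference levels may differ. The data frame df and its columns are made up for this example.

```r
# Toy data with one numeric column and one 3-level factor (hypothetical)
df <- data.frame(
  age  = c(30, 45, 52, 38),
  site = factor(c("A", "B", "C", "A"))
)

# Removing the intercept (- 1) gives one 0/1 dummy column per factor level
dummies <- model.matrix(~ site - 1, data = df)

# Combine numeric predictors and dummy columns into a single numeric matrix
x_enc <- cbind(age = df$age, dummies)
x_enc
```

Each row has exactly one of the dummy columns set to 1, so every level gets its own predictor with a direction that can be interpreted in the same way as a binary factor.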
explain() runs progressively more slowly with larger numbers of input
predictors. So if you have a very large number of predictors in x, it is
much better to pass only the subset of predictors in the final model.
These are stored in fitted nestedcv objects in list element xsub. Or you
can subset x using final_vars. For example:
sh <- explain(fit, X = fit$xsub, pred_wrapper = pred_nestcv_glmnet, nsim = 5)
plot_shap_bar(sh, fit$xsub)
sh <- explain(fit, X = x[, fit$final_vars], pred_wrapper = pred_nestcv_glmnet, nsim = 5)
plot_shap_bar(sh, x[, fit$final_vars])
If you use this package, please cite as:
Lewis MJ, Spiliopoulou A, Goldmann K, Pitzalis C, McKeigue P, Barnes MR (2023). nestedcv: an R package for fast implementation of nested cross-validation with embedded feature selection designed for transcriptomics and high dimensional data. Bioinformatics Advances. https://doi.org/10.1093/bioadv/vbad048
Brandon Greenwell. fastshap: Fast Approximate Shapley Values
Lundberg SM & Lee SI (2017). A Unified Approach to Interpreting Model Predictions. Advances in Neural Information Processing Systems 30; 4768–4777. http://papers.nips.cc/paper/7062-a-unified-approach-to-interpreting-model-predictions.pdf
Lundberg SM et al (2020). From Local Explanations to Global Understanding with Explainable AI for Trees. Nat Mach Intell 2020; 2(1): 56-67. https://arxiv.org/abs/1905.04610