8 Explainable AI

The goal of explainable AI (xAI, aka interpretable machine learning) is to explain why a fitted machine learning model makes certain predictions. A typical example is to understand how important different variables are for predictions. The incentives for doing so range from a better technical understanding of the models, to understanding which data is important for improving predictions, to questions of fairness and discrimination (e.g. to understand whether an algorithm uses skin color to make a decision).

8.1 A Practical Example

In this lecture we will work with an African Elephant occurrence dataset.

We will fit a random forest and use the iml package for xAI, see https://christophm.github.io/interpretable-ml-book/.

library(iml)
library(ranger) # a different random forest package!
library(EcoData)
library(cito)
set.seed(123)

data = EcoData::elephant$occurenceData
head(data)
?EcoData::elephant
Meaning of the bioclim variables:
Bioclim variable | Meaning |
---|---|
bio1 | Annual Mean Temperature |
bio2 | Mean Diurnal Range (Mean of monthly (max temp - min temp)) |
bio3 | Isothermality (BIO2/BIO7) (×100) |
bio4 | Temperature Seasonality (standard deviation ×100) |
bio5 | Max Temperature of Warmest Month |
bio6 | Min Temperature of Coldest Month |
bio7 | Temperature Annual Range (BIO5-BIO6) |
bio8 | Mean Temperature of Wettest Quarter |
bio9 | Mean Temperature of Driest Quarter |
bio10 | Mean Temperature of Warmest Quarter |
bio11 | Mean Temperature of Coldest Quarter |
bio12 | Annual Precipitation |
bio13 | Precipitation of Wettest Month |
bio14 | Precipitation of Driest Month |
bio15 | Precipitation Seasonality (Coefficient of Variation) |
bio16 | Precipitation of Wettest Quarter |
bio17 | Precipitation of Driest Quarter |
bio18 | Precipitation of Warmest Quarter |
bio19 | Precipitation of Coldest Quarter |
rf = ranger(as.factor(Presence) ~ ., data = data, probability = TRUE)
The ranger package provides an alternative implementation of the random forest algorithm. The implementation in the ranger package is one of the fastest available, which is especially important for explainable AI. Most xAI tools require hundreds of predictions (e.g. feature importance permutes each feature n times to calculate the performance drop), so a fast implementation of the ML algorithm is crucial.
Important:
For binary classification tasks, it is critical to change the response variable to a factor before fitting ranger! Otherwise, ranger will use the first value in the response variable as “Class 1”!
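A quick check (a sketch) of which level ranger treats as which class: for a 0/1 presence variable converted to a factor, "0" becomes the first level and "1" the second, so the second column of the probability predictions should correspond to P(Presence = 1).

levels(as.factor(data$Presence))                      # "0" "1" - order of the classes
colnames(predict(rf, data = head(data))$predictions)  # columns of the probability predictions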
The cito package has quite extensive xAI functionalities. However, ranger, like most other machine learning packages, has no built-in xAI functionalities. Thus, to do xAI with ranger, we have to use a generic xAI package that can handle almost all machine learning models.

When we want to use such a generic package, we first have to create a predictor object that holds the model and the data. The iml package uses R6 classes, which means that new objects are created by calling Predictor$new(). (Do not worry if you do not know what R6 classes are, just use the command.)

To make the xAI tools available to many different packages/algorithms, the iml package expects the ML-algorithm-specific predict method to be wrapped in a generic predict function of the form function(model, newdata) predict(model, newdata), and this wrapper should return a vector of predictions:
predict_wrapper = function(model, newdata) predict(model, data = newdata)$predictions[,2]

# "Predictor" is an object generator.
predictor = Predictor$new(rf, data = data[,-1], y = data[,1], predict.function = predict_wrapper)
predictor$task = "classif" # set task to classification
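A quick sanity check (sketch): the wrapper should indeed return a plain numeric vector of predicted presence probabilities, one value per row of newdata.

str(predict_wrapper(rf, data[1:5, -1])) # numeric vector of length 5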
8.2 Feature/Permutation Importance
Feature importance should not be confused with random forest variable importance, although they are related. It tells us how important each variable is for prediction, can be computed for all machine learning models, and is based on a permutation approach (see the book):
imp = FeatureImp$new(predictor, loss = "ce")
plot(imp)
bio9 (Mean Temperature of Driest Quarter) is the most important variable.
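To see what FeatureImp does conceptually, here is a rough manual sketch of permutation importance for a single feature (bio9), assuming the rf and predict_wrapper objects from above; for simplicity it uses the Brier score instead of the cross-entropy loss used above:

baseline_pred = predict_wrapper(rf, data[,-1])
baseline_error = mean((data$Presence - baseline_pred)^2) # Brier score on the unpermuted data

n_permute = 10
permuted_error = replicate(n_permute, {
  permuted = data[,-1]
  permuted$bio9 = sample(permuted$bio9) # destroy the association between bio9 and the response
  mean((data$Presence - predict_wrapper(rf, permuted))^2)
})

mean(permuted_error) / baseline_error # ratio > 1: the model relies on bio9 for its predictions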
8.3 Partial Dependencies
Partial dependencies are similar to allEffects plots for normal regressions. The idea is to visualize “marginal effects” of predictors (with the “feature” argument we specify the variable we want to visualize):
eff = FeatureEffect$new(predictor, feature = "bio9", method = "pdp",
                        grid.size = 30)
plot(eff)
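Conceptually, the partial dependence is obtained by fixing bio9 at a grid of values for all observations and averaging the resulting predictions. A minimal manual sketch (assuming the rf and predict_wrapper objects from above):

grid = seq(min(data$bio9), max(data$bio9), length.out = 30)
pdp_manual = sapply(grid, function(value) {
  tmp = data[,-1]
  tmp$bio9 = value                 # fix bio9 at the grid value for all observations
  mean(predict_wrapper(rf, tmp))   # average prediction = "marginal effect" at this value
})
plot(grid, pdp_manual, type = "l", xlab = "bio9", ylab = "Mean predicted occurrence probability")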
One disadvantage of partial dependencies is that they are sensitive to correlated predictors. Accumulated local effects can be used to account for correlations among predictors.
8.4 Accumulated Local Effects
Accumulated local effects (ALE) are basically partial dependence plots that try to correct for correlations between predictors.
= FeatureEffect$new(predictor, feature = "bio9", method = "ale")
ale $plot() ale
If there is no collinearity, you shouldn’t see much difference between partial dependencies and ALE plots.
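Whether the two plots differ here therefore depends on how strongly bio9 is correlated with the other predictors, which we can check quickly:

# correlation of bio9 with the other bioclim predictors
round(sort(cor(data[,-1])[, "bio9"], decreasing = TRUE), 2)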
8.5 Friedman’s H-statistic
The H-statistic can be used to find interactions between predictors. However, again, keep in mind that the H-statistic is sensitive to correlations between predictors:
interact = Interaction$new(predictor, "bio9", grid.size = 5L)
plot(interact)
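If we omit the feature argument, iml computes the overall interaction strength of every predictor with all remaining predictors; a sketch (again with a small grid.size to keep the runtime manageable):

interact_all = Interaction$new(predictor, grid.size = 5L)
plot(interact_all)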
8.6 Global Explainer - Simplifying the Machine Learning Model
Another idea is to approximate the machine learning model with a simpler, interpretable model such as a decision tree. We create predictions with the machine learning model for many different input values and then fit a decision tree to these predictions. We can then interpret the simpler model.
library(partykit)
tree = TreeSurrogate$new(predictor, maxdepth = 2)
plot(tree$tree)
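The same idea can be done by hand: generate predictions with the random forest and fit a small regression tree on them. A minimal sketch using rpart (an alternative to iml's TreeSurrogate):

library(rpart)

surrogate_data = data[,-1]
surrogate_data$p_rf = predict_wrapper(rf, data[,-1]) # predictions of the black-box model

surrogate_tree = rpart(p_rf ~ ., data = surrogate_data,
                       control = rpart.control(maxdepth = 2))
plot(surrogate_tree, margin = 0.1) # the simple tree approximates the random forest
text(surrogate_tree)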
8.7 Local Explainer - LIME Explaining Single Instances (observations)
The global approach simplifies the entire black-box machine learning model via a simpler model, which is then interpretable.

Sometimes, however, we are only interested in understanding how single predictions are generated. The LIME (Local Interpretable Model-agnostic Explanations) approach explores the feature space around one observation and, based on this, locally fits a simpler model (e.g. a linear model):
lime.explain = LocalModel$new(predictor, x.interest = data[1,-1])
lime.explain$results
plot(lime.explain)
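Because the explanation is local, different observations usually get different explanations; for example (a sketch, assuming the data has at least 100 rows):

lime.explain2 = LocalModel$new(predictor, x.interest = data[100, -1])
plot(lime.explain2)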
8.8 Local Explainer - Shapley
The Shapley method computes the so-called Shapley value, i.e. feature contributions for single predictions, and is based on an approach from cooperative game theory. The idea is that each feature value of the instance is a “player” in a game, where the prediction is the reward. The Shapley value tells us how to fairly distribute the reward among the features.
shapley = Shapley$new(predictor, x.interest = data[1,-1])
shapley$plot()
8.9 Uncertainties - the bootstrap
Standard xAI methods do not provide reliable uncertainties for the fitted curves. If you want uncertainties or p-values, the most common method is the bootstrap.

In a bootstrap, instead of splitting the data into training / validation sets, we sample from the data with replacement and fit the model repeatedly. The idea is to get an estimate of the variability we would expect if we had collected another dataset of the same size.
k = 10 # bootstrap samples
n = nrow(data)
error = rep(NA, k)

for(i in 1:k){
  bootSample = sample.int(n, n, replace = TRUE)
  rf = ranger(as.factor(Presence) ~ ., data = data[bootSample,], probability = TRUE)
  error[i] = rf$prediction.error
}
hist(error, main = "uncertainty of in-sample error")
Note that the distinction between bootstrap and validation / cross-validation is as follows:
- Validation / cross-validation estimates out-of-sample predictive error
- Bootstrap estimates uncertainty / confidence intervals for any model output (both predictions and inference).
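The bootstrap can also be wrapped around an xAI output itself, for example the partial dependence of bio9, to get an uncertainty band on the curve. A minimal (and slow) sketch, assuming the elephant data from above; the grid size and the subsample of observations are arbitrary choices to keep the runtime down:

k = 10                                          # number of bootstrap samples
grid = seq(min(data$bio9), max(data$bio9), length.out = 20)
obs = data[sample.int(nrow(data), 200), -1]     # subsample of observations to average over
pdp_boot = matrix(NA, nrow = k, ncol = length(grid))

for(i in 1:k){
  bootSample = sample.int(nrow(data), nrow(data), replace = TRUE)
  rf_b = ranger(as.factor(Presence) ~ ., data = data[bootSample,], probability = TRUE)
  for(j in seq_along(grid)){
    tmp = obs
    tmp$bio9 = grid[j]                          # fix bio9 at the grid value
    pdp_boot[i, j] = mean(predict(rf_b, data = tmp)$predictions[,2])
  }
}

# bootstrap median and 90% interval of the partial dependence curve
quants = apply(pdp_boot, 2, quantile, probs = c(0.05, 0.5, 0.95))
matplot(grid, t(quants), type = "l", lty = c(2, 1, 2), col = 1,
        xlab = "bio9", ylab = "Mean predicted occurrence probability")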
8.10 Exercises
Data preparation
library(iml)
library(cito)
library(EcoData)

data = EcoData::elephant$occurenceData
head(data)
?EcoData::elephant

# we will subsample the data to reduce runtime
data_sub = data[sample.int(nrow(data), 500),]
cito includes several xAI methods directly out of the box:
model = dnn(Presence~., data = data_sub, batchsize = 200L, loss = "binomial", verbose = FALSE, lr = 0.15, epochs = 300)
Try the following commands:
summary(model, n_permute = 10)
PDP(model)
ALE(model)
Moreover, try to refit the model with the option bootstrap = 5. This may take a short while. Observe how the xAI outputs change.
model = dnn(Presence~., data = data_sub, batchsize = 200L, bootstrap = 5L, loss = "binomial", verbose = FALSE, lr = 0.15, epochs = 300)
summary(model, n_permute = 10L)
PDP(model)
ALE(model)
Use the titanic_ml dataset and fit a random forest, a dnn, or a BRT using xgboost. Explore / interpret the fitted model using iml (see also the book: https://christophm.github.io/interpretable-ml-book/).
Tip:
If you use iml, you need to provide a proper prediction function wrapper:
# random forest (ranger), regression:
predict_wrapper = function(model, newdata) predict(model, data = newdata)$predictions

# random forest (ranger), classification:
predict_wrapper = function(model, newdata) predict(model, data = newdata)$predictions[,2]

# xgboost:
predict_wrapper = function(model, newdata) predict(model, as.matrix(newdata))
Prepare the data
library(EcoData)
library(dplyr)
library(missRanger) # for imputation

data = titanic_ml

# feature selection
data = data %>% select(survived, sex, age, fare, pclass) # play around with the features

# imputation - remove the response variable first!
head(data)
data_imputed = data
data_imputed[,-1] = missRanger(data_imputed[,-1])
summary(data_imputed)

# scale numeric features and encode categorical features as integers
data_imputed = data_imputed %>%
  mutate(age = (age - mean(age))/sd(age), fare = (fare - mean(fare))/sd(fare),
         sex = as.integer(sex), pclass = as.integer(pclass))

data_obs = data_imputed[!is.na(data_imputed$survived), ]
data_new = data_imputed[is.na(data_imputed$survived), ]
library(ranger)
library(iml)
set.seed(1234)

data_obs$survived = as.factor(data_obs$survived)

rf = ranger(survived ~ ., data = data_obs, importance = "impurity", probability = TRUE)

# For submission:
# write.csv(data.frame(y = predict(rf, data_new)$predictions[,2]), file = "wine_RF.csv")

# Standard depiction of importance:
ranger::importance(rf)

# Setup wrapper
predict_wrapper = function(model, newdata) predict(model, data = newdata)$predictions[,2]

# IML:
predictor = Predictor$new(
  rf, data = data_obs[, which(names(data_obs) != "survived")],
  y = as.integer(data_obs$survived) - 1,
  predict.function = predict_wrapper
)

# Mind: This is stochastic!
importance = FeatureImp$new(predictor, loss = "logLoss")
plot(importance)

# Comparison between standard importance and IML importance:
importanceRf = names(rf$variable.importance)[order(rf$variable.importance, decreasing = TRUE)]
importanceIML = importance$results[1]
comparison = cbind(importanceIML, importanceRf)
colnames(comparison) = c("IML", "RF")
as.matrix(comparison)
Mind that feature importance and the random forest's variable importance are related but not identical! Variable importance is computed while growing the forest (i.e. during fitting), whereas feature importance measures how important a variable is for prediction.

You may also want to look at other explanation methods; you can apply the other techniques from this section to this model on your own.
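For example, the predictor object defined above can be reused directly with the other iml methods (a sketch):

# partial dependence of age
eff_age = FeatureEffect$new(predictor, feature = "age", method = "pdp")
plot(eff_age)

# Shapley values for the first passenger
shap = Shapley$new(predictor, x.interest = data_obs[1, which(names(data_obs) != "survived")])
shap$plot()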
library(xgboost)
library(iml)
set.seed(1234)

data_xg = xgb.DMatrix(
  data = as.matrix(data_obs[, which(names(data_obs) != "survived")]),
  label = as.integer(data_obs$survived) - 1
)
brt = xgboost(data_xg, nrounds = 24, objective = "reg:logistic")

# For submission:
# write.csv(round(predict(brt, data_new)), file = "wine_RF.csv")

# Standard depiction of importance:
xgboost::xgb.importance(model = brt)

# Setup wrapper
predict_wrapper = function(model, newdata) predict(model, as.matrix(newdata))

# IML:
predictor = Predictor$new(
  brt, data = data_obs[, which(names(data_obs) != "survived")],
  y = as.integer(data_obs$survived) - 1,
  predict.function = predict_wrapper
)

# Mind: This is stochastic!
importance = FeatureImp$new(predictor, loss = "logLoss")
plot(importance)
library(cito)

data_obs$survived = as.integer(data_obs$survived) - 1
nn = dnn(survived~., data = data_obs, loss = "binomial", lr = 0.03, epochs = 300)

summary(nn)