8  Explainable AI

The goal of explainable AI (xAI, also known as interpretable machine learning) is to explain why a fitted machine learning model makes certain predictions. A typical example is to understand how important different variables are for predictions. The incentives for doing so range from a better technical understanding of the models, to understanding which data are important for improving predictions, to questions of fairness and discrimination (e.g. to understand whether an algorithm uses skin color to make a decision).

8.1 A Practical Example

In this lecture we will work with an African Elephant occurrence dataset.

We will fit a random forest and use the iml package for xAI, see https://christophm.github.io/interpretable-ml-book/.

library(iml)
library(ranger) # different random Forest package!
library(EcoData)
library(cito)
set.seed(123)


data = EcoData::elephant$occurenceData
head(data)
      Presence       bio1       bio2       bio3       bio4        bio5
3364         0 -0.4981747 -0.2738045  0.5368968 -0.5409999 -0.36843571
6268         0  0.6085908 -0.5568352  1.0340686 -1.2492050 -0.11835651
10285        0 -0.7973005  1.4648130 -1.0540532  2.0759423  0.07614953
2247         0  0.6385034  1.3435141 -0.1591439 -0.5107148  1.10425291
9821         0  0.6684160 -0.6781341  0.6363311 -0.9906170  0.15950927
1351         0  0.9675418 -0.6781341 -0.3580126 -0.3748202  0.77081398
            bio6       bio7       bio8       bio9       bio10       bio11
3364   0.2947850 -0.5260099 -1.2253960  0.2494100 -0.64527314 -0.06267842
6268   0.8221087 -0.8938475  0.4233787  0.7746249  0.09168503  0.94419518
10285 -1.5860029  1.6284678  0.2768209 -1.5153122 -0.03648161 -1.44165748
2247  -0.1622288  0.8577603  0.4600181  0.5855475  0.54026827  0.68153250
9821   0.9099960 -0.8062671  0.3867393  0.8586593  0.31597665  0.94419518
1351   0.8748411 -0.3858812  0.3134604  1.0477367  0.98885151  0.94419518
           bio12      bio13       bio14        bio15      bio16      bio17
3364   0.6285371  0.6807958 -0.29703736 -0.008455252  0.7124535 -0.2949994
6268   1.1121516  0.5918442  0.01619202 -0.884507980  0.5607328  0.3506918
10285 -1.2351482 -1.3396742 -0.50585695  0.201797403 -1.3499999 -0.5616980
2247   0.5951165  0.8714061 -0.55806185  0.236839512  1.1012378 -0.5616980
9821   1.1003561  0.5537222  0.59044589 -1.024676416  0.6413344  0.7437213
1351   0.7287986  1.1255533 -0.50585695  0.236839512  1.2956300 -0.4494038
            bio18       bio19
3364  -1.06812752  1.96201807
6268   1.22589281 -0.36205814
10285 -0.42763181 -0.62895735
2247  -0.20541902 -0.58378979
9821   0.06254347 -0.05409751
1351  -0.90473576  2.47939193
?EcoData::elephant

Meaning of the bioclim variables:

Bioclim variable   Meaning
bio1               Annual Mean Temperature
bio2               Mean Diurnal Range (Mean of monthly (max temp - min temp))
bio3               Isothermality (BIO2/BIO7) (×100)
bio4               Temperature Seasonality (standard deviation ×100)
bio5               Max Temperature of Warmest Month
bio6               Min Temperature of Coldest Month
bio7               Temperature Annual Range (BIO5-BIO6)
bio8               Mean Temperature of Wettest Quarter
bio9               Mean Temperature of Driest Quarter
bio10              Mean Temperature of Warmest Quarter
bio11              Mean Temperature of Coldest Quarter
bio12              Annual Precipitation
bio13              Precipitation of Wettest Month
bio14              Precipitation of Driest Month
bio15              Precipitation Seasonality (Coefficient of Variation)
bio16              Precipitation of Wettest Quarter
bio17              Precipitation of Driest Quarter
bio18              Precipitation of Warmest Quarter
bio19              Precipitation of Coldest Quarter

rf = ranger(as.factor(Presence) ~ ., data = data, probability = TRUE)

The cito package has quite extensive xAI functionality built in. ranger, however, like most other machine learning packages, does not. To do xAI with ranger, we therefore have to use a generic xAI package that can handle almost any machine learning model.

When we want to use such a generic package, we first have to create a predictor object that holds the model and the data. The iml package uses R6 classes, which means that new objects are created by calling Predictor$new(). (Do not worry if you do not know what R6 classes are, just use the command.)

We often have to wrap our predict function inside a so-called wrapper function so that its output fits what iml expects (iml expects the predict function to return a vector of predictions):

predict_wrapper = function(model, newdata) predict(model, data=newdata)$predictions[,2]

predictor = Predictor$new(rf, data = data[,-1], y = data[,1], predict.function = predict_wrapper)
predictor$task = "classif" # set task to classification
# "Predictor" is an object generator.

8.2 Feature Importance

Feature importance should not be mistaken for the random forest's variable importance, though they are related. It tells us how important the individual variables are for predictions, can be calculated for any machine learning model, and is based on a permutation approach (have a look at the book):

imp = FeatureImp$new(predictor, loss = "ce")
plot(imp)

bio9 (Mean Temperature of the Driest Quarter) is the most important variable.
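
For intuition, the permutation approach can also be sketched by hand: shuffle one predictor, re-predict, and check how much the prediction error increases relative to the unshuffled baseline. A minimal sketch (illustration only, not iml's exact implementation; it reuses the ranger predict function from above):

# Minimal sketch of permutation importance (illustration only)
perm_importance = function(model, data, response = "Presence", n_perm = 5) {
  y = data[[response]]
  base_error = mean((predict(model, data = data)$predictions[,2] - y)^2)
  features = setdiff(names(data), response)
  sapply(features, function(f) {
    errors = replicate(n_perm, {
      shuffled = data
      shuffled[[f]] = sample(shuffled[[f]]) # break the association with the response
      mean((predict(model, data = shuffled)$predictions[,2] - y)^2)
    })
    mean(errors) / base_error # ratio > 1 means the feature matters for predictions
  })
}
sort(perm_importance(rf, data), decreasing = TRUE)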

8.3 Partial Dependencies

Partial dependencies are similar to allEffects plots for normal regressions. The idea is to visualize “marginal effects” of predictors (with the “feature” argument we specify the variable we want to visualize):

eff = FeatureEffect$new(predictor, feature = "bio9", method = "pdp",
                        grid.size = 30)
plot(eff)
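
Conceptually, the partial dependence at a given value of bio9 is simply the average prediction after fixing bio9 at that value for all observations. A minimal sketch of this marginalization (illustration only, not iml's implementation):

# Manual partial dependence curve for bio9
grid = seq(min(data$bio9), max(data$bio9), length.out = 30)
pdp_vals = sapply(grid, function(v) {
  tmp = data
  tmp$bio9 = v # fix bio9 at the grid value for all observations
  mean(predict(rf, data = tmp)$predictions[,2]) # average prediction over the data
})
plot(grid, pdp_vals, type = "l", xlab = "bio9", ylab = "Predicted probability of presence")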

One disadvantage of partial dependencies is that they are sensitive to correlated predictors. Accumulated local effects can be used to account for correlations among predictors.

8.4 Accumulated Local Effects

Accumulated local effects (ALE) are basically partial dependence plots that try to correct for correlations between predictors.

ale = FeatureEffect$new(predictor, feature = "bio9", method = "ale")
ale$plot()

If there is no collinearity, you shouldn’t see much difference between partial dependencies and ALE plots.
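
The underlying idea can be sketched as follows: split the range of bio9 into intervals, compute within each interval the average prediction difference between its upper and lower bound (using only the observations that actually fall into that interval), and accumulate these local differences. A simplified, uncentered illustration (the real ALE algorithm additionally centers the curve):

# Simplified (uncentered) ALE sketch for bio9
K = 20
breaks = quantile(data$bio9, probs = seq(0, 1, length.out = K + 1))
bins = cut(data$bio9, breaks, include.lowest = TRUE, labels = FALSE)
local_effects = sapply(1:K, function(k) {
  idx = which(bins == k) # only observations within this interval
  lower = upper = data[idx, ]
  lower$bio9 = breaks[k]
  upper$bio9 = breaks[k + 1]
  mean(predict(rf, data = upper)$predictions[,2] -
       predict(rf, data = lower)$predictions[,2])
})
plot(breaks[-1], cumsum(local_effects), type = "l",
     xlab = "bio9", ylab = "Accumulated local effect (uncentered)")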

8.5 Friedman’s H-statistic

The H-statistic can be used to find interactions between predictors. However, keep in mind that the H-statistic is also sensitive to correlations between predictors:

interact = Interaction$new(predictor, "bio9", grid.size = 5L)
plot(interact)

8.6 Global Explainer - Simplifying the Machine Learning Model

Another idea is to simplify the machine learning model with a simpler model such as a decision tree. We create predictions with the machine learning model for many different input values and then fit a decision tree on these predictions. We can then interpret the simpler model.

library(partykit)

tree = TreeSurrogate$new(predictor, maxdepth = 2)
plot(tree$tree)
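
The same idea can be written out by hand: generate predictions with the random forest and fit a shallow decision tree to these predictions, then check how well the simple tree mimics the forest. A sketch, assuming the rpart package is installed:

# Manual surrogate: shallow regression tree fitted to the forest's predictions
library(rpart)

surrogate_data = data[,-1]
surrogate_data$rf_pred = predict(rf, data = data)$predictions[,2] # black-box predictions
surrogate = rpart(rf_pred ~ ., data = surrogate_data, maxdepth = 2)
surrogate
cor(predict(surrogate), surrogate_data$rf_pred)^2 # how much of the forest's behavior the tree captures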

8.7 Local Explainer - LIME Explaining Single Instances (observations)

The global approach is to simplify the entire black-box machine learning model via a simpler model, which is then interpretable.

However, sometimes we are only interested in understanding how single predictions are generated. The LIME (Local Interpretable Model-agnostic Explanations) approach explores the feature space around one observation and, based on this, fits a simpler model (e.g. a linear model) locally:

lime.explain = LocalModel$new(predictor, x.interest = data[1,-1])
lime.explain$results
             beta x.recoded       effect        x.original feature
bio9  -0.03972318 0.2494100 -0.009907356 0.249409955204759    bio9
bio16 -0.12035200 0.7124535 -0.085745198 0.712453479144842   bio16
                feature.value
bio9   bio9=0.249409955204759
bio16 bio16=0.712453479144842
plot(lime.explain)
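
To see what happens behind the scenes, here is a rough sketch of the LIME idea: weight all observations by their proximity to the instance of interest and fit a weighted linear model to the black-box predictions. This is an illustration only (iml additionally selects features via a penalized regression, and the kernel width used here is an arbitrary choice); bio9 and bio16 are used because LocalModel selected them above:

# Rough LIME sketch for the first observation
x_interest = data[1, -1]
preds = predict(rf, data = data[,-1])$predictions[,2] # black-box predictions
dists = sqrt(rowSums(sweep(as.matrix(data[,-1]), 2, unlist(x_interest))^2))
weights = exp(-dists^2 / 0.75) # Gaussian proximity kernel (arbitrary width)
local_model = lm(preds ~ bio9 + bio16, data = data[,-1], weights = weights)
coef(local_model) # local linear effects around the instance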

8.8 Local Explainer - Shapley

The Shapley method computes the so-called Shapley value, i.e. feature contributions for single predictions, and is based on an approach from cooperative game theory. The idea is that each feature value of the instance is a “player” in a game, where the prediction is the reward. The Shapley value tells us how to fairly distribute the reward among the features.

shapley = Shapley$new(predictor, x.interest = data[1,-1])
shapley$plot()

8.9 Uncertainties - the bootstrap

Standard xAI methods do not provide reliable uncertainties on the fitted curves. If you want uncertainties or p-values, the most common approach is the bootstrap.

In a bootstrap, instead of splitting the data into training and validation sets, we sample from the data with replacement and refit the model repeatedly. The idea is to get an estimate of the variability we would expect if we had collected another dataset of the same size.

k = 10 # bootstrap samples
n = nrow(data)
error = rep(NA, k)

for(i in 1:k){
  bootSample = sample.int(n, n, replace = TRUE)
  rf = ranger(as.factor(Presence) ~ ., data = data[bootSample,], probability = TRUE)
  error[i] = rf$prediction.error
}

hist(error, main = "uncertainty of in-sample error")

Note that the distinction between bootstrap and validation / cross-validation is as follows:

  • Validation / cross-validation estimates the out-of-sample predictive error.
  • The bootstrap estimates the uncertainty / confidence intervals of any model output (predictions as well as inferential quantities).
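
The bootstrap can also be combined with the xAI tools from above to get uncertainty bands, e.g. for the partial dependence of bio9. A sketch that refits the random forest on each bootstrap sample (reusing data, k and n from above, so it takes a moment to run):

# Bootstrapped partial dependence of bio9
grid = seq(min(data$bio9), max(data$bio9), length.out = 20)
pdp_boot = matrix(NA, nrow = k, ncol = length(grid))

for(i in 1:k){
  bootSample = sample.int(n, n, replace = TRUE)
  rf_boot = ranger(as.factor(Presence) ~ ., data = data[bootSample,], probability = TRUE)
  pdp_boot[i,] = sapply(grid, function(v){
    tmp = data
    tmp$bio9 = v
    mean(predict(rf_boot, data = tmp)$predictions[,2])
  })
}

matplot(grid, t(pdp_boot), type = "l", lty = 1, col = "grey",
        xlab = "bio9", ylab = "Predicted probability of presence")
lines(grid, colMeans(pdp_boot), lwd = 2) # average curve over the bootstrap samples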

8.10 Exercises

xAI in cito

Data preparation

library(iml)
library(cito)
library(EcoData)


data = EcoData::elephant$occurenceData
head(data)
?EcoData::elephant

# we will subsample the data to reduce runtime
data_sub = data[sample.int(nrow(data), 500),]

cito includes many xAI methods directly out of the box.

model = dnn(Presence~., data = data_sub, batchsize = 200L,loss = "binomial", verbose = FALSE, lr = 0.15, epochs = 300)

Try the following commands:

  • summary(model, n_permute = 10)
  • PDP(model)
  • ALE(model)

Moreover, try to refit the model with the option bootstrap = 5L. This may take a short while. Observe how the xAI outputs change.

model = dnn(Presence~., data = data_sub, batchsize = 200L, bootstrap = 5L, loss = "binomial", verbose = FALSE, lr = 0.15, epochs = 300)
summary(model, n_permute = 10L)
Summary of Deep Neural Network Model
── Feature Importance  
         Importance Std.Err Z value Pr(>|z|)   
bio1 →        2.435   1.464    1.66   0.0962 . 
bio2 →        0.779   0.606    1.28   0.1989   
bio3 →        0.557   0.181    3.08   0.0021 **
bio4 →        1.740   1.308    1.33   0.1834   
bio5 →        0.646   0.416    1.55   0.1200   
bio6 →        0.462   0.246    1.88   0.0604 . 
bio7 →        0.500   0.635    0.79   0.4315   
bio8 →        1.291   1.200    1.08   0.2817   
bio9 →        4.212   2.681    1.57   0.1162   
bio10 →       0.476   0.281    1.69   0.0905 . 
bio11 →       0.676   0.300    2.26   0.0241 * 
bio12 →       1.687   1.417    1.19   0.2336   
bio13 →       1.255   1.924    0.65   0.5143   
bio14 →       1.936   0.898    2.16   0.0311 * 
bio15 →       0.665   0.329    2.02   0.0436 * 
bio16 →       2.865   4.647    0.62   0.5375   
bio17 →       0.918   0.833    1.10   0.2708   
bio18 →       1.437   0.495    2.90   0.0037 **
bio19 →       0.194   0.160    1.21   0.2258   
---
Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
── Average Conditional Effects 
              ACE  Std.Err Z value Pr(>|z|)    
bio1 →    0.14675  0.04333    3.39  0.00071 ***
bio2 →   -0.06460  0.04505   -1.43  0.15161    
bio3 →    0.00939  0.05942    0.16  0.87449    
bio4 →    0.05846  0.04832    1.21  0.22632    
bio5 →    0.03613  0.07135    0.51  0.61262    
bio6 →    0.05161  0.05003    1.03  0.30229    
bio7 →    0.00663  0.05997    0.11  0.91196    
bio8 →   -0.04220  0.06996   -0.60  0.54634    
bio9 →   -0.11845  0.05833   -2.03  0.04230 *  
bio10 →  -0.01999  0.04003   -0.50  0.61758    
bio11 →  -0.01264  0.03791   -0.33  0.73893    
bio12 →  -0.09244  0.07697   -1.20  0.22979    
bio13 →   0.05657  0.09094    0.62  0.53395    
bio14 →   0.14664  0.06821    2.15  0.03156 *  
bio15 →  -0.02688  0.03485   -0.77  0.44053    
bio16 →  -0.10496  0.09431   -1.11  0.26571    
bio17 →   0.01204  0.06987    0.17  0.86322    
bio18 →   0.04106  0.04151    0.99  0.32255    
bio19 →   0.00801  0.03019    0.27  0.79075    
---
Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
── Standard Deviation of Conditional Effects  
            ACE Std.Err Z value Pr(>|z|)    
bio1 →   0.2382  0.0614    3.88  0.00010 ***
bio2 →   0.1128  0.0663    1.70  0.08883 .  
bio3 →   0.0975  0.0233    4.18  3.0e-05 ***
bio4 →   0.1572  0.0480    3.27  0.00107 ** 
bio5 →   0.1094  0.0413    2.65  0.00808 ** 
bio6 →   0.1011  0.0356    2.84  0.00458 ** 
bio7 →   0.0818  0.0449    1.82  0.06811 .  
bio8 →   0.1453  0.0656    2.22  0.02672 *  
bio9 →   0.2435  0.0710    3.43  0.00061 ***
bio10 →  0.1089  0.0261    4.17  3.0e-05 ***
bio11 →  0.1215  0.0308    3.94  8.0e-05 ***
bio12 →  0.1776  0.0697    2.55  0.01082 *  
bio13 →  0.1487  0.1304    1.14  0.25399    
bio14 →  0.2339  0.0606    3.86  0.00011 ***
bio15 →  0.0919  0.0291    3.16  0.00158 ** 
bio16 →  0.1807  0.1411    1.28  0.20038    
bio17 →  0.1170  0.0539    2.17  0.02986 *  
bio18 →  0.2286  0.0439    5.21  1.9e-07 ***
bio19 →  0.0544  0.0226    2.41  0.01595 *  
---
Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
PDP(model)
ALE(model)
Question

Use the titanic_ml dataset and fit a random forest, a dnn (cito), or a BRT (xgboost). Explore / interpret the fitted model using iml (see also the book: https://christophm.github.io/interpretable-ml-book/).

Tip:

If you use iml, you need to provide a proper prediction function wrapper:

# random Forest (ranger), regression:
predict_wrapper = function(model, newdata) predict(model, data=newdata)$predictions

# random Forest (ranger), classification:
predict_wrapper = function(model, newdata) predict(model, data=newdata)$predictions[,2]

# xgboost:
predict_wrapper = function(model, newdata) predict(model, as.matrix(newdata))

Prepare the data

library(EcoData)
library(dplyr)
library(missRanger) # for imputation


data = titanic_ml

# feature selection
data = data %>% select(survived, sex, age, fare, pclass) # play around with the features

# imputation - remove response variable!
head(data)
     survived    sex  age   fare pclass
561         1 female 30.0 13.000      2
321         1   male   NA 35.500      1
1177        0   male   NA 69.550      3
1098        0   male  6.0 21.075      3
1252        0   male 30.5  8.050      3
1170        0   male 38.5  7.250      3
data_imputed = data
data_imputed[,-1] = missRanger(data_imputed[,-1])

Missing value imputation by random forests

  Variables to impute:      age, fare
  Variables used to impute: sex, age, fare, pclass

summary(data_imputed)
    survived          sex           age               fare        
 Min.   :0.0000   female:466   Min.   : 0.1667   Min.   :  0.000  
 1st Qu.:0.0000   male  :843   1st Qu.:22.0000   1st Qu.:  7.896  
 Median :0.0000                Median :27.7958   Median : 14.454  
 Mean   :0.3853                Mean   :29.5833   Mean   : 33.281  
 3rd Qu.:1.0000                3rd Qu.:36.0000   3rd Qu.: 31.275  
 Max.   :1.0000                Max.   :80.0000   Max.   :512.329  
 NA's   :655                                                      
     pclass     
 Min.   :1.000  
 1st Qu.:2.000  
 Median :3.000  
 Mean   :2.295  
 3rd Qu.:3.000  
 Max.   :3.000  
                
data_imputed = data_imputed %>% 
  mutate(age = (age - mean(age))/sd(age), fare = (fare - mean(fare))/sd(fare),
         sex = as.integer(sex), pclass = as.integer(pclass))

data_obs = data_imputed[!is.na(data_imputed$survived), ]
data_new = data_imputed[is.na(data_imputed$survived), ]
library(ranger)
library("iml")
set.seed(1234)
data_obs$survived = as.factor(data_obs$survived)

rf = ranger(survived ~ ., data = data_obs, importance = "impurity", probability = TRUE)

# For submission:
#write.csv(data.frame(y=predict(rf, data_new)$predictions[,2]), file = "titanic_RF.csv")

# Standard depiction of importance:
ranger::importance(rf)
     sex      age     fare   pclass 
73.14943 48.54354 58.23967 22.98536 
# Setup wrapper
predict_wrapper = function(model, newdata) predict(model, data=newdata)$predictions[,2]


# IML:
predictor = Predictor$new(
    rf, data = data_obs[,which(names(data_obs) != "survived")], y = as.integer(data_obs$survived)-1,
    predict.function = predict_wrapper
    )

# Mind: This is stochastic!
importance = FeatureImp$new(predictor, loss = "logLoss")

plot(importance)

# Comparison between standard importance and IML importance:
importanceRf = names(rf$variable.importance)[order(rf$variable.importance, decreasing = TRUE)]
importanceIML = importance$results[1]
comparison = cbind(importanceIML, importanceRf)
colnames(comparison) = c("IML", "RF")
as.matrix(comparison)
     IML      RF      
[1,] "sex"    "sex"   
[2,] "pclass" "fare"  
[3,] "fare"   "age"   
[4,] "age"    "pclass"

Mind that feature importance and the random forest’s variable importance are related but not equal! Variable importance is computed while building the forest (i.e. during fitting). Feature importance measures how important a variable is for prediction.

You may want to try the other explanation methods as well; all techniques from this section can be applied here in the same way.
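
For example, a Shapley explanation for a single passenger works exactly as in the elephant example, reusing the predictor object defined above:

# Shapley values for the first passenger in the training data
shapley = Shapley$new(predictor, x.interest = data_obs[1, which(names(data_obs) != "survived")])
plot(shapley)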

library(xgboost)
library("iml")
set.seed(1234)


data_xg = xgb.DMatrix(
  data = as.matrix(data_obs[,which(names(data_obs) != "survived")]),
  label = as.integer(data_obs$survived)-1
)
brt = xgboost(data_xg, nrounds = 24, objective = "reg:logistic")
[1] train-rmse:0.436392 
[2] train-rmse:0.397434 
[3] train-rmse:0.371319 
[4] train-rmse:0.353402 
[5] train-rmse:0.338994 
[6] train-rmse:0.331816 
[7] train-rmse:0.325260 
[8] train-rmse:0.320678 
[9] train-rmse:0.315598 
[10]    train-rmse:0.311736 
[11]    train-rmse:0.307083 
[12]    train-rmse:0.301911 
[13]    train-rmse:0.299988 
[14]    train-rmse:0.297729 
[15]    train-rmse:0.296160 
[16]    train-rmse:0.293888 
[17]    train-rmse:0.291231 
[18]    train-rmse:0.289166 
[19]    train-rmse:0.285938 
[20]    train-rmse:0.283317 
[21]    train-rmse:0.278259 
[22]    train-rmse:0.274736 
[23]    train-rmse:0.272635 
[24]    train-rmse:0.271098 
# For submission:
#write.csv(round(predict(brt, as.matrix(data_new[,which(names(data_new) != "survived")]))), file = "titanic_BRT.csv")

# Standard depiction of importance:
xgboost::xgb.importance(model = brt)
   Feature      Gain      Cover  Frequency
    <char>     <num>      <num>      <num>
1:     sex 0.3559182 0.12948531 0.03904555
2:    fare 0.2925936 0.39055610 0.48806941
3:     age 0.2018321 0.38698694 0.39696312
4:  pclass 0.1496561 0.09297165 0.07592191
# Setup wrapper
predict_wrapper = function(model, newdata) predict(model, as.matrix(newdata))


# IML:
predictor = Predictor$new(
    brt, data = data_obs[,which(names(data_obs) != "survived")], y = as.integer(data_obs$survived)-1,
    predict.function = predict_wrapper
    )

# Mind: This is stochastic!
importance = FeatureImp$new(predictor, loss = "logLoss")

plot(importance)
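
All other iml tools work on the BRT predictor in the same way, e.g. an accumulated local effects plot for age:

eff = FeatureEffect$new(predictor, feature = "age", method = "ale")
plot(eff)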

library(cito)
data_obs$survived = as.integer(data_obs$survived) - 1
nn = dnn(survived~., data = data_obs, loss = "binomial", lr= 0.03, epochs = 300)
Loss at epoch 1: 0.680140, lr: 0.03000
...
Loss at epoch 300: 0.451175, lr: 0.03000
summary(nn)
Summary of Deep Neural Network Model

Feature Importance:
  variable importance_1
1      sex     1.657285
2      age     1.161388
3     fare     1.030452
4   pclass     1.541186

Average Conditional Effects:
        Response_1
sex    -0.40100236
age    -0.10234583
fare   -0.06016106
pclass -0.20588012

Standard Deviation of Conditional Effects:
       Response_1
sex    0.20862693
age    0.06573672
fare   0.07341272
pclass 0.12912456
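
As for the elephant model, cito's built-in effect plots can be used to inspect the fitted relationships of the network:

PDP(nn)
ALE(nn)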